【transformers.Trainer填坑】在自定义compute_metrics时logits和labels数据维度不一致问题

问题描述

我在使用 transformers.Trainer 训练我的模型时,我自定义了 compute_loss 函数和compute_metrics函数,我的模型是一个简单的二分类模型。

在自定义 compute_loss 时这样写的:

def compute_loss(self, model, inputs, return_outputs=False):"""重写 Trainer.compute_loss:1) 提取字典中的 images, bboxes, locs, labels 等2) 用 vision_encoder 先处理图像,得到特征3) 用下游 model 做预测4) 计算并返回 loss"""# 前向传播outputs, labels = model(**inputs)  # (bz, num_classes), or (bz*num_frames, num_classes)batch_size = inputs['labels'].shape[0]outputs = outputs.squeeze()  # (bz*num_frames)if batch_size == 1:outputs = outputs.unsqueeze(0)# 计算 lossloss = self.loss_func(outputs, labels.float())if self.state.global_step % 10 == 0 and self.state.global_step > 0:# 以50个step为间隔打印pred_probs = torch.sigmoid(outputs)preds = (pred_probs > 0.5).int()logger.info(f"[global_step={self.state.global_step}] preds={preds.tolist()} / labels={labels.tolist()} / loss={loss.item():.4f}")# compute metricaccuracy = accuracy_score(labels.cpu().numpy(), preds.cpu().numpy())precision = precision_score(labels.cpu().numpy(), preds.cpu().numpy())recall = recall_score(labels.cpu().numpy(), preds.cpu().numpy())logger.info(f"[global_step={self.state.global_step}] accuracy={accuracy:.4f} / precision={precision:.4f} / recall={recall:.4f}")# 返回 (loss, outputs) 或者只返回 lossreturn (loss, outputs) if return_outputs else loss

于是就出现了报错,像这样的:

File "/opt/conda/lib/python3.9/site-packages/transformers/trainer.py", line 3754, in predictoutput = eval_loop(File "/opt/conda/lib/python3.9/site-packages/transformers/trainer.py", line 3966, in evaluation_loopmetrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))File "/workspace/train/object_query/train.py", line 281, in compute_metricscorrect_num = preds == labels
ValueError: operands could not be broadcast together with shapes (11720,) (12104,)output = eval_loop(File "/opt/conda/lib/python3.9/site-packages/transformers/trainer.py", line 3966, in evaluation_loopmetrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))File "/workspace/train/object_query/train.py", line 281, in compute_metricscorrect_num = preds == labels
ValueError: operands could not be broadcast together with shapes (11720,) (12104,)

原因

该问题是 transformers.Trainer 内部有一段对outputs的操作造成的:

if isinstance(outputs, dict):logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
else:logits = outputs[1:]

这里当 outputs 不是字典时,会把第一个位置的元素offset掉。

解决

Refer to here
所以,我们应该在返回那里这样写:

return (loss, {"label": outputs}) if return_outputs else loss

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/18875.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

类与对象C++详解(上)

目录 1.类的定义 1.1 类定义格式 补充: struct与class的区别(c语言与c) 1.2 访问限定符 1.3 类域 2.实例化 3.对象大小 4.this指针 1.类的定义 1.1 类定义格式 class为定义类的关键字,Stack为类的名字,{}中为类的主体&…

LabVIEW 天然气水合物电声联合探测

天然气水合物被认为是潜在的清洁能源,其储量丰富,预计将在未来能源格局中扮演重要角色。由于其独特的物理化学特性,天然气水合物的探测面临诸多挑战,涉及温度、压力、电学信号、声学信号等多个参数。传统的人工操作方式不仅效率低…

Windows上安装Go并配置环境变量(图文步骤)

前言 1. 本文主要讲解的是在windows上安装Go语言的环境和配置环境变量; Go语言版本:1.23.2 Windows版本:win11(win10通用) 下载Go环境 下载go环境:Go下载官网链接(https://golang.google.cn/dl/) 等待…

神经网络常见激活函数 9-CELU函数

文章目录 CELU函数导函数函数和导函数图像优缺点pytorch中的CELU函数tensorflow 中的CELU函数 CELU 连续可微指数线性单元:CELU(Continuously Differentiable Exponential Linear Unit),是一种连续可导的激活函数,结合了 ELU 和 …

《安富莱嵌入式周报》第350期:Google开源Pebble智能手表,开源模块化机器人平台,开源万用表,支持10GHz HRTIM的单片机,开源CNC控制器

周报汇总地址:嵌入式周报 - uCOS & uCGUI & emWin & embOS & TouchGFX & ThreadX - 硬汉嵌入式论坛 - Powered by Discuz! 视频版: https://www.bilibili.com/video/BV1YPKEeyEeM/ 《安富莱嵌入式周报》第350期:Google开…

小米平板怎么和电脑共享屏幕

最近尝试使用小米平板和电脑屏幕分屏互联 发现是需要做特殊处理的,需要下载一款电脑安装包:小米妙享 关于这个安装包,想吐槽的是: 没有找到官网渠道,是通过其他网络方式查到下载的 不附录链接,原因是因为地…

(学习总结23)Linux 目录、通配符、重定向、管道、shell、权限与粘滞位

Linux 目录、通配符、重定向、管道、shell、权限与粘滞位 Linux 目录通配符重定向管道shell 介绍Linux 权限Linux 权限的概念用户切换命令 su Linux权限管理文件访问者的分类:常用文件类型与其标识符:文件基本权限和权限值的表示方法:更改文件…

深入解析操作系统控制台:阿里云Alibaba Cloud Linux(Alinux)的运维利器

作为一位个人开发者兼产品经理,我的工作日常紧密围绕着云资源的运维和管理。在这个过程中,操作系统扮演了至关重要的角色,而操作系统控制台则成为了我们进行系统管理的得力助手。本文将详细介绍阿里云的Alibaba Cloud Linux操作系统控制台的功…

Android10 音频参数导出合并

A10 设备录音时底噪过大,让音频同事校准了下,然后把校准好的参数需要导出来,集成到项目中,然后出包,导出方式在此记录 设备安装debug系统版本调试好后, adb root adb remount adb shell 进入设备目录 导…

dnslog+sqlmap外带数据

目录 爆库 爆表 爆列 爆数据 sqlmapDNSlog 外带参数 –dns-domain参数注入 –dns-domain参数为dnslog平台的域名(我们也可以使用本地) 爆库 python sqlmap.py -u "http://127.0.0.1/sqli/less-8/index.php/?id1" -techniqueB -dns-dom…

sql注入中information_schema被过滤的问题

目录 一、information_schema库的作用 二、获得表名 2.1 sys.schema_auto_increment_columns 2.2 schema_table_statistics 三、获得列名 join … using … order by盲注 子查询 在进行sql注入时,我们经常会使用information_schema来进行爆数据库名、表名、…

Flutter 常见布局模型

Flutter的常见的布局模型有容器(Container)、弹性盒子布局(Flex、Row、Column、Expanded)、流式布局(Wrap、Flow)、层叠布局(Stack、Position)、滚动布局(ListView、Grid…

深度学习框架探秘|TensorFlow:AI 世界的万能钥匙

在人工智能(AI)蓬勃发展的时代,各种强大的工具和框架如雨后春笋般涌现,而 TensorFlow 无疑是其中最耀眼的明星之一。它不仅被广泛应用于学术界的前沿研究,更是工业界实现 AI 落地的关键技术。今天,就让我们…

TypeScript 与后端开发Node.js

文章目录 一、搭建 TypeScript Node.js 项目 (一)初始化项目并安装相关依赖 1、创建项目目录并初始化2、安装必要的依赖包 (二)配置 TypeScript 编译选项(如模块解析方式适合后端) 二、编写服务器代码 &a…

DDD该怎么去落地实现(3)通用的仓库和工厂

通用的仓库和工厂 我有一个梦,就是希望DDD能够成为今后软件研发的主流,越来越多研发团队都转型DDD,采用DDD的设计思想和方法,设计开发软件系统。这个梦想在不久的将来是有可能达成的,因为DDD是软件复杂性的解决之道&a…

国家队出手!DeepSeek上线国家超算互联网平台!

目前,国家超算互联网平台已推出 DeepSeek – R1 模型的 1.5B、7B、8B、14B 版本,后续还会在近期更新 32B、70B 等版本。 DeepSeek太火爆了!在这个春节档,直接成了全民热议的话题。 DeepSeek也毫无悬念地干到了全球增速最快的AI应用。这几天,国内的云计算厂家都在支持Dee…

内容中台驱动企业数字化内容管理高效协同架构

内容概要 在数字化转型加速的背景下,企业对内容管理的需求从单一存储向全链路协同演进。内容中台作为核心支撑架构,通过统一的内容资源池与智能化管理工具,重塑了内容生产、存储、分发及迭代的流程。其核心价值在于打破部门壁垒,…

算法1-1 玩具谜题

题目描述 小南有一套可爱的玩具小人,它们各有不同的职业。 有一天,这些玩具小人把小南的眼镜藏了起来。小南发现玩具小人们围成了一个圈,它们有的面朝圈内,有的面朝圈外。如下图: 这时 singer 告诉小南一个谜题&…

vs2022支持.netframework4.0

下载nuget包 .netframework4.0 解压nuget 复制到C:\Program Files (x86)\Reference Assemblies\Microsoft\Framework\.NETFramework 参考 https://www.cnblogs.com/bdqczhl/p/18670152 https://blog.csdn.net/xiaomeng1998_/article/details/135979884

数据治理常用的开源项目有哪些?

数据治理是企业在大数据时代中确保数据质量、安全性和可用性的关键环节。开源项目在数据治理中扮演着重要角色,提供了灵活、经济高效且功能强大的解决方案。以下是一些常用的开源数据治理项目: Apache Atlas: 功能:元数据管理、数…