理解Pytorch中的collate_fn函数

PyTorch中的DataLoader是最常用的类之一,这个类有很多参数(14 个),但大多数情况下,你可能只会使用其中的三个:dataset、shuffle 和 batch_size。其中collate_fn是比较少用的函数,这对初学者来说是一个容易混淆的概念。下面将简要探讨 PyTorch 如何创建批次,并了解如何根据我们的需求修改其默认行为。

批处理

首先,先创建一个数据,用data变量标视

import torch
from torch.utils.data import DataLoader
import numpy as npdata = np.array([[0.1, 7.4, 0],[-0.2, 5.3, 0],[0.2, 8.2, 1],[0.2, 7.7, 1]])
print(data)

如果我们加载一个批次出来( shuffle=False以消除随机性):

loader = DataLoader(data, batch_size=2, shuffle=False)
batch = next(iter(loader))
print(batch)# tensor([[ 0.1000,  7.4000,  0.0000],
#         [-0.2000,  5.3000,  0.0000]], dtype=torch.float64)

结果符合预期的,我们来解释一下已经做了什么:

  • 加载器从数据集中选择了 2 个样本。
  • 这些样本被转换为张量(2 个大小为 3 的样本)。
  • 创建并返回一个新的张量 (2x3)。

默认设置还允许我们使用字典。 让我们看一个例子:

from pprint import pprint
# now dataset is a list of dicts
dict_data = [{'x1': 0.1, 'x2': 7.4, 'y': 0},{'x1': -0.2, 'x2': 5.3, 'y': 0},{'x1': 0.2, 'x2': 8.2, 'y': 1},{'x1': 0.2, 'x2': 7.7, 'y': 10},
]
pprint(dict_data)
# [{'x1': 0.1, 'x2': 7.4, 'y': 0},
# {'x1': -0.2, 'x2': 5.3, 'y': 0},
# {'x1': 0.2, 'x2': 8.2, 'y': 1},
# {'x1': 0.2, 'x2': 7.7, 'y': 10}]loader = DataLoader(dict_data, batch_size=2, shuffle=False)
batch = next(iter(loader))
pprint(batch)
# {'x1': tensor([ 0.1000, -0.2000], dtype=torch.float64),
#  'x2': tensor([7.4000, 5.3000], dtype=torch.float64),
#  'y': tensor([0, 0])}

Dataloader简单易用,可以正确地从字典列表中重新打包数据。 当你的数据采用 JSON格式时,此功能非常方便。

自定义collate函数

Dataloader默认设置能覆盖大部分场景的数据读取,但默认设置有一个很大的限制——批数据必须处于同一维度。 假设我们有一个 NLP 任务,并且数据是分词后的文本。

# values are token indices but it does not matter - it can be any kind of variable-size data
nlp_data = [{'tokenized_input': [1, 4, 5, 9, 3, 2],'label':0},{'tokenized_input': [1, 7, 3, 14, 48, 7, 23, 154, 2],'label':0},{'tokenized_input': [1, 30, 67, 117, 21, 15, 2],'label':1},{'tokenized_input': [1, 17, 2],'label':0},
]
loader = DataLoader(nlp_data, batch_size=2, shuffle=False)
batch = next(iter(loader))

这样强行去压成一个batch存储,会引发错误:

/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/collate.py in default_collate(batch)80         elem_size = len(next(it))81         if not all(len(elem) == elem_size for elem in it):
---> 82             raise RuntimeError('each element in list of batch should be of equal size')83         transposed = zip(*batch)84         return [default_collate(samples) for samples in transposed]RuntimeError: each element in list of batch should be of equal size

报错信息显示不能创建非矩形张量。顺便说一句,可以看到触发错误的是 default_collate函数。

如何修改? 有两种解决方案:

  • 将整个数据集填充到最长的样本。
  • 在Batch data创建期间进行动态填充。

第一个方法最简单,但是非常耗内存,极端条件下,我们有1万条数据,其中9999条数据长度是10,而只有1条数据长度是1000,那么所有的数据都需要pad数值,使得长度填充到1000。这样有99%的内存占用都是无意义的。

pad-max

另一种方法是动态填充数据。 当选择这一个batch的样本时,我们只需要将数据填充到这一个batch最长的样本长度即可。另外,将数据按照长度进行排序,则填充的数量将是最小的。 如果有一些非常长的序列,它们只会影响它们的这一个batch的效率,而不是整个数据集。

pad-batch

具体如何实现呢?

from torch.nn.utils.rnn import pad_sequence #(1)def custom_collate(data): #(2)inputs = [torch.tensor(d['tokenized_input']) for d in data] #(3)labels = [d['label'] for d in data]inputs = pad_sequence(inputs, batch_first=True) #(4)labels = torch.tensor(labels) #(5)return { #(6)'tokenized_input': inputs,'label': labels}loader = DataLoader(nlp_data, batch_size=2, shuffle=False, collate_fn=custom_collate
) #(7)iter_loader = iter(loader)
batch1 = next(iter_loader)
pprint(batch1)
batch2 = next(iter_loader)
pprint(batch2)# {'label': tensor([0, 0]),
#  'tokenized_input': tensor([
#   [  1,   4,   5,   9,   3,   2,   0,   0,   0],
#   [  1,   7,   3,  14,  48,   7,  23, 154,   2]
# ])}# {'label': tensor([1, 0]),
#  'tokenized_input': tensor([
#   [  1,  30,  67, 117,  21,  15,   2],
#   [  1,  17,   2,   0,   0,   0,   0]])}

代码功能如下:

  • 我们使用 pad_sequence进行填充
  • custom_collate作为参数传递给DataLoader
  • 在运行时对inputs进行动态填充

总结

collate_fn是一个很少用的函数,但对提升训练效率有很大的帮助。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/402347.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux线程间通信学习记录(线程同步)

0.线程间通信的方法 (1).全局变量(要结合同步机制) (2).信号量 (3).P操作 (4).V操作 一.线程同步 同步:指的是多个任务按照约定的先后次序相互…

Visual C++ 2010 学习版

这个版本很好用。 在这里放一个链接,做个备份。 这个版本是承前启后的版本,非常的重要。 一、使用VC2010 这个版本创建的解决方案可以在VS2010~VS2022版本中打开,反之也行。 二、使用VC2010 可以编绎VC6.0 ~VC2008的项目。可以使用现成的…

灵办AI助手Chrome插件全面评测:PC Web端的智能办公利器

探索灵办AI助手在Mac OS上的高效表现,支持多款主流浏览器,助你轻松应对办公挑战 文章目录 探索灵办AI助手在Mac OS上的高效表现,支持多款主流浏览器,助你轻松应对办公挑战摘要引言开发环境介绍核心功能评测1. 网页翻译与双语对照 …

Rancher 使用 Minio 备份 Longhorn 数据卷

0. 概述 Longhorn 支持备份到 NFS 或者 S3, 而 MinIO 就是符合 S3 的对象存储服务。通过 docker 部署 minio 服务,然后在 Longhorn UI 中配置备份服务即可。 1. MinIO 部署 1.1 创建备份目录 mkdir -p /home/longhorn-backup/minio/data mkdir -p /home/longhor…

RCE的另外一些绕过练习

目录 被过滤了flag怎么办 方法 结果 过滤了flag、php、system 方法一 结果 ​编辑 方法二 过滤了很多但是主要的就是过滤了空格 和 注意一下这个就行 方法一 方法二 相对于上面一道题来说多过滤了一个括号 方法一 被过滤了flag怎么办 <?php error_reportin…

Python3网络爬虫开发实战(10)模拟登录(需补充账号池的构建)

文章目录 一、基于 Cookie 的模拟登录二、基于 JWT 模拟登入三、账号池四、基于 Cookie 模拟登录爬取实战五、基于JWT 的模拟登录爬取实战六、构建账号池 很多情况下&#xff0c;网站的一些数据需要登录才能查看&#xff0c;如果需要爬取这部分的数据&#xff0c;就需要实现模拟…

K8S - ConfigMap的简介和使用

什么是configMap Kubernetes中的ConfigMap 是用于存储非敏感数据的API对象&#xff0c;用于将配置数据与应用程序的镜像分离。ConfigMap可以包含键值对、文件或者环境变量等配置信息&#xff0c;应用程序可以通过挂载ConfigMap来访问其中的数据&#xff0c;从而实现应用配置的…

ubuntu20 lightdm无法自动登录进入桌面

现象&#xff1a;在rk3568的板子上自己做了一个Ubuntu 20.04的桌面系统。配置lightdm自动登录桌面&#xff0c;配置方法如下&#xff1a; $ vim /etc/lightdm/lightdm.conf [Seat:*] user-sessionxubuntu autologin-userusername #修改成自动登录的用户名 greeter-show-m…

38-PCB布局实战实战及优化

1.先对布局好的器件进行锁定 1.根据模块化布局 2.电容尽量靠近ic附近&#xff0c;可以起到很好的滤波效果 3.复位按键尽量摆在容易按键的地方&#xff0c;比如周围 。。。。 最后进行对齐

【OCR 学习笔记】二值化——局部阈值方法

二值化——局部阈值方法 自适应阈值算法Niblack算法Sauvola算法 自适应阈值算法 自适应阈值算法1用到了积分图&#xff08;Integral Image&#xff09;的概念。积分图中任意一点 ( x , y ) (x,y) (x,y)的值是从图左上角到该点形成的矩形区域内所有值的和。即&#xff1a; I (…

模板[C++]

目录 1.&#x1f680;泛型编程&#x1f680; 2.&#x1f680;函数模板&#x1f680; 2.1 ✈️函数模板概念✈️ 2.2 ✈️函数模板格式✈️ 2.3✈️函数模板的原理✈️ 2.4 ✈️函数模板的实例化✈️ 2.5 ✈️模板参数的匹配原则✈️ 3.&#x1f680;类模板&#x1f680…

文件中找TopK问题 的详细讲解

一&#xff1a;问题&#xff1a; 从一个包含10000整数的文件中找出最大的前10个数。 二&#xff1a;方法&#xff1a; 1&#xff1a;先直接拿文件的前10个数&#xff0c;建造一个小堆 2&#xff1a;再依次读取文件中&#xff0c;剩下的数&#xff0c;比堆顶大&#xff0c;则…

学习记录第二十九天

信号量————来描述可使用资源的个数 信号量&#xff08;Semaphore&#xff09;是一种用于控制多个进程或线程对共享资源访问的同步机制。在C语言中&#xff0c;通常我们会使用POSIX线程&#xff08;pthread&#xff09;库来实现信号量的操作 信号量有两个主要操作&#xf…

C语言 ——— 位段(位域)

目录 什么是位段 位段的内存分配 什么是位段 位段的声明和结构体是类似的 但有两个不同&#xff1a; 1. 位段的成员必须是整型家族&#xff1a; int&#xff08;整型&#xff09; &#xff0c;unsigend int &#xff08;无符号整型&#xff09;&#xff0c;sigend int&…

【初阶数据结构题目】32. 希尔排序

文章目录 希尔排序希尔排序的时间复杂度计算 希尔排序 希尔排序法又称缩小增量法。希尔排序法的基本思想是&#xff1a;先选定一个整数&#xff08;通常是gap n/31&#xff09;&#xff0c;把待排序文件所有记录分成各组&#xff0c;所有的距离相等的记录分在同一组内&#x…

歌曲爬虫下载

本次编写一个程序要爬取歌曲音乐榜https://www.onenzb.com/ 里面歌曲。有帮到铁子的可以收藏和关注起来&#xff01;&#xff01;&#xff01;废话不多说直接上代码。 1 必要的包 import requests from lxml import html,etree from bs4 import BeautifulSoup import re impo…

Qt作业合集

8.14作业 设置窗口&#xff0c;按钮&#xff0c;标签&#xff0c;行编辑器&#xff0c;实现快递速运登录页面 #include "mywidget.h"MyWidget::MyWidget(QWidget *parent): QWidget(parent) {//窗口//设置窗口的标题this->setWindowTitle("邮递系统")…

蚂蚁AL1 15.6T 创新科技的新典范

● 哈希率&#xff1a;算力达到15.6T&#xff08;相当于15600G&#xff09;&#xff0c;即每秒能够进行15.6万亿次哈希计算&#xff0c;在同类产品中算力较为出色&#xff0c;能提高WA掘效率。 ● 功耗&#xff1a;功耗为3510W&#xff0c;虽然数值看似不低&#xff0c;但结合其…

内存泄漏之如何使用Visual Studio的调试工具跟踪内存泄漏?

使用Visual Studio的调试工具跟踪内存泄漏是一个系统性的过程&#xff0c;主要包括启用内存泄漏检测、运行程序、分析内存使用情况以及定位泄漏源等步骤。 Visual Studio提供了多种方式来检测内存泄漏&#xff0c;你可以根据自己的需求选择合适的方法。 注意&#xff1a;下面…

【TiDB】10-对 TiDB 进行 TPC-C 测试

目录 1、安装bench工具 2、插入数据 3、运行测试 4、测试结果分析 4.1、总体性能概览 4.2、事务类型详细性能 4.3、错误事务分析 4.4、结论与建议 5、清理测试数据 TPC-C 是一个对 OLTP&#xff08;联机交易处理&#xff09;系统进行测试的规范&#xff0c;使用一个商…