Pytorch指定数据加载器使用子进程

torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True,num_workers=4, pin_memory=True)

num_workers 参数是 DataLoader 类的一个参数,它指定了数据加载器使用的子进程数量。通过增加 num_workers 的数量,可以并行地读取和预处理数据,从而提高数据加载的速度。

通常情况下,增加 num_workers 的数量可以提高数据加载的效率,因为它可以使数据加载和预处理工作在多个进程中同时进行。然而,当 num_workers 的数量超过一定阈值时,增加更多的进程可能不会再带来更多的性能提升,甚至可能会导致性能下降。

这是因为增加 num_workers 的数量也会增加进程间通信的开销。当 num_workers 的数量过多时,进程间通信的开销可能会超过并行化所带来的收益,从而导致性能下降。

此外,还需要考虑到计算机硬件的限制。如果你的计算机 CPU 核心数量有限,增加 num_workers 的数量也可能会导致性能下降,因为每个进程需要占用 CPU 核心资源。

因此,对于 num_workers 参数的设置,需要根据具体情况进行调整和优化。通常情况下,一个合理的 num_workers 值应该在 2 到 8 之间,具体取决于你的计算机硬件配置和数据集大小等因素。在实际应用中,可以通过尝试不同的 num_workers 值来找到最优的配置。

综上所述,当 num_workers 的值从 4 增加到 8 时,如果你的计算机硬件配置和数据集大小等因素没有发生变化,那么两者之间的性能差异可能会很小,或者甚至没有显著差异。

测试代码如下

import torch
import torchvision
import matplotlib.pyplot as plt
import torchvision.models as models
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
import timeif __name__ == '__main__':mp.freeze_support()train_on_gpu = torch.cuda.is_available()if not train_on_gpu:print('CUDA is not available. Training on CPU...')else:print('CUDA is available! Training on GPU...')device = torch.device("cuda" if torch.cuda.is_available() else "cpu")batch_size = 4# 设置数据预处理的转换transform = torchvision.transforms.Compose([torchvision.transforms.Resize((512,512)),  # 调整图像大小为 224x224torchvision.transforms.ToTensor(),  # 转换为张量torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 归一化])dataset = torchvision.datasets.ImageFolder('C:\\Users\\ASUS\\PycharmProjects\\pythonProject1\\cats_and_dogs_train',transform=transform)val_ratio = 0.2val_size = int(len(dataset) * val_ratio)train_size = len(dataset) - val_sizetrain_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])train_dataset = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True,num_workers=4, pin_memory=True)val_dataset = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True,num_workers=4, pin_memory=True)model = models.resnet18()num_classes = 2for param in model.parameters():param.requires_grad = Falsemodel.fc = nn.Sequential(nn.Dropout(),nn.Linear(model.fc.in_features, num_classes),nn.LogSoftmax(dim=1))optimizer = optim.Adam(model.parameters(), lr=0.001)criterion = nn.CrossEntropyLoss().to(device)model.to(device)filename = "recognize_cats_and_dogs.pt"def save_checkpoint(epoch, model, optimizer, filename):checkpoint = {'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss,}torch.save(checkpoint, filename)num_epochs = 3train_loss = []for epoch in range(num_epochs):running_loss = 0correct = 0total = 0epoch_start_time = time.time()for i, (inputs, labels) in enumerate(train_dataset):# 将数据放到设备上inputs, labels = inputs.to(device), labels.to(device)# 前向计算outputs = model(inputs)# 计算损失和梯度loss = criterion(outputs, labels)optimizer.zero_grad()loss.backward()# 更新模型参数optimizer.step()# 记录损失和准确率running_loss += loss.item()train_loss.append(loss.item())_, predicted = torch.max(outputs.data, 1)correct += (predicted == labels).sum().item()total += labels.size(0)accuracy_train = 100 * correct / total# 在测试集上计算准确率with torch.no_grad():running_loss_test = 0correct_test = 0total_test = 0for inputs, labels in val_dataset:inputs, labels = inputs.to(device), labels.to(device)outputs = model(inputs)loss = criterion(outputs, labels)running_loss_test += loss.item()_, predicted = torch.max(outputs.data, 1)correct_test += (predicted == labels).sum().item()total_test += labels.size(0)accuracy_test = 100 * correct_test / total_test# 输出每个 epoch 的损失和准确率epoch_end_time = time.time()epoch_time = epoch_end_time - epoch_start_timeprint("Epoch [{}/{}], Time: {:.4f}s, Loss: {:.4f}, Train Accuracy: {:.2f}%, Loss: {:.4f}, Test Accuracy: {:.2f}%".format(epoch + 1, num_epochs,epoch_time,running_loss / len(val_dataset),accuracy_train, running_loss_test / len(val_dataset), accuracy_test))save_checkpoint(epoch, model, optimizer, filename)plt.plot(train_loss, label='Train Loss')# 添加图例和标签plt.legend()plt.xlabel('Epochs')plt.ylabel('Loss')plt.title('Training Loss')# 显示图形plt.show()

不同num_workers的结果如下

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/174716.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

分布式:一文吃透分布式锁,Redis/Zookeeper/MySQL实现

目录 一、项目准备spring项目数据库 二、传统锁演示超卖现象使用JVM锁解决超卖解决方案JVM失效场景 使用一个SQL解决超卖使用mysql悲观锁解决超卖使用mysql乐观锁解决超卖四种锁比较Redis乐观锁集成Redis超卖现象redis乐观锁解决超卖 三、分布式锁概述四、Redis分布式锁实现方案…

Linux 文件系统简介

文章目录 一、磁盘简介1.1 简介1.2 机械硬盘与固态硬盘1.2.1 机械磁盘(HDD)1.2.2 固态磁盘(SSD)1.2.3 I/O操作 二、文件系统简介2.1. 简介2.2 文件系统特点2.3 Linux文件系统 三、文件数据存储方式3.1 连续存储3.2 链接表存储3.3 …

前端知识与基础应用#2

标签的分类 关于标签我们可以分为 : 单标签:img, br hr 双标签:a,h,div 按照属性可分为: 块儿标签(自己独自占一行):h1-h6, p,div 行内(内联)标签&#xff08…

One-to-N N-to-One: Two Advanced Backdoor Attacks Against Deep Learning Models

One-to-N & N-to-One: Two Advanced Backdoor Attacks Against Deep Learning Models----《一对N和N对一:针对深度学习模型的两种高级后门攻击》 1对N: 通过控制同一后门的不同强度触发多个后门 N对1: 只有当所有N个后门都满足时才会触发…

3.5每日一题(求齐次方程组的特解)

1、判断类型选择方法:看出为齐次方程(次幂都一样) 2、 化为变量可分离;按变量可分离的方法求出通解(此题等式两边同时除以 x ) 3、把x1,y0带入通解,定常数C,求出特解 …

用大白话聊聊SpringBoot的自动配置原理(面试题详解)

首先,SpringBoot的自动配置不等于自动装配! 自动配置是Auto-Configuration,针对的是SpringBoot中的配置类, 而自动装配是Autowire,针对的是Spring中的依赖注入。 进入主题: 自动配置简单来说就是自动去把…

java八股文(基础篇)

面向过程和面向对象的区别 面向过程:在解决问题时,特别自定义函数编写一步一步的步骤解决问题。 面向对象:其特点就是 继承,多态,继承,在解决问题时,不再注重函数的编写,而在于注重…

Spring Boot 3系列之一(初始化项目)

近期,JDK 21正式发布,而Spring Boot 3也推出已有一段时间。作为这两大技术领域的新一代标杆,它们带来了许多令人振奋的新功能和改进。尽管已有不少博客和文章对此进行了介绍,但对于我们这些身处一线的开发人员来说,有些…

【Linux】从零开始学习Linux基本指令(三)

🚩纸上得来终觉浅, 绝知此事要躬行。 🌟主页:June-Frost 🚀专栏:Linux入门 🔥该文章主要了解Linux操作系统下的基本指令。 ⚡️该篇为Linux指令部分的终章,如果您想了解前两篇文章的…

【Docker】Linux网络命名空间

命名空间 Namespace是Linux提供的一种对于系统全局资源的隔离机制;从进程的视角来看,同一个namespace中的进程看到的是该namespace自己独立的一份全局资源,这些资源的变化只在本namespace中可见,对其他namespace没有影响。容器就…

Linux学习第26天:异步通知驱动开发: 主动

Linux版本号4.1.15 芯片I.MX6ULL 大叔学Linux 品人间百味 思文短情长 在正式开启今天的学习前,讲一讲为什么标题中加入了【主动】俩字。之前学习的阻塞和非阻塞IO,都是在被动的接受应用程序的操作。而今天的学…

rust入门

目录 一,输入输出 二,函数 1,main函数 2,普通函数 3,库函数 4,常用库函数 三,变量 1,变量绑定、let、mut 2,变量作用域 四,数据结构 1&#xff0c…

风云七剑攻略,最强阵容搭配

今天的风云七剑攻略最强阵容搭配给大家推荐以神仙斋减怒回血为主的阵容。 关注【娱乐天梯】,获取内部福利号 首先,这个角色在这个阵容当中,所有的角色当中,他的输出系数是最高的,已经达到了200%的层次,而且…

商业模式画布的9大模块全解读,产品经理不可不知!

“商场如战场”,在当今瞬息万变的商业环境中,创造出独特且创新的商业模式是每个企业家、策略家和决策者的首要任务。为了在激烈的市场竞争中取得优势,我们需要一个强大且直观的工具来帮助我们规划和塑造公司的商业模式,这个经常被…

H5游戏源码分享-跳得更高

H5游戏源码分享-跳得更高 控制跳动踩到云朵上 <!DOCTYPE html> <html> <head><meta http-equiv"Content-Type" content"text/html; charsetUTF-8"><meta http-equiv"Content-Type" content"text/html;"&g…

【.NET Core】创建一个在后台运行的控制台程序(ConsoleApp)

文章目录 1. 添加Nuget包2. 修改Program.cs3. 添加TestService 借助.NET的通用主机&#xff08;IHostBuilder&#xff09;可以轻易创建一个可以执行后台任务的程序 1. 添加Nuget包 Microsoft.Extensions.Hosting 2. 修改Program.cs 通过Host获取IHostService&#xff0c;然…

[UDS] --- ECUReset 0x11

1 0x11功能描述 根据ISO14119-1标准中所述&#xff0c;诊断服务11主要用于Client向Server(ECU)请求重启行为。该重启行为将会导致Server复位回归到特定的初始状态&#xff0c;具体是什么初始状态取决于Client的请求行为。 2 0x11应用场景 一般而言&#xff0c;对于11诊断服务…

Ansible的安装和部署

目录 1.Ansible的安装 2.构建Ansible清单 直接书写受管主机名或ip 设定受管主机的组[组名称] 主机规格的范围化操作 指定其他清单文件 ansible命令指定清单的正则表达式 3.Ansible配置文件参数详解 配置文件的分类与优先级 常用配置参数 4.构建用户级Ansible操作环…

goland无法调试问题解决

goland 无法调试问题解决 golang 版本升级后&#xff0c;goland 无法进行调试了 首先请看自己下载的版本是否有误 1.apple系 M系列芯片的 arm64版本 2.apple系 intel系列芯片的x86_64 3.windows系 intel解决如下&#xff1a; 查看gopath ericsanchezErics-Mac-mini gww-api…

docker 安装minio,访问地址进不去

文章目录 黑马头条P37docker安装minio文图一、启动后页面一直是加载状态进不去 黑马头条P37docker安装minio文图 一、启动后页面一直是加载状态进不去 通过docker logs -f (容器id)查看日志 通过这个报错信息&#xff0c;得知最近minio 升级&#xff0c;一些启动信息和之前不…