如何指定多块GPU卡进行训练-数据并行

训练代码:

train.py

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F# 假设我们有一个简单的文本数据集
class TextDataset(Dataset):def __init__(self, texts, labels, vocab):self.texts = textsself.labels = labelsself.vocab = vocabdef __len__(self):return len(self.texts)def __getitem__(self, idx):text = self.texts[idx]label = self.labels[idx]# 将文本转换为索引text_indices = [self.vocab.get(word, self.vocab['<UNK>']) for word in text.split()]return torch.tensor(text_indices, dtype=torch.long), torch.tensor(label, dtype=torch.long)# 定义一个简单的LSTM分类器
class LSTMClassifier(nn.Module):def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim):super(LSTMClassifier, self).__init__()self.embedding = nn.Embedding(vocab_size, embedding_dim)self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)self.fc = nn.Linear(hidden_dim, output_dim)def forward(self, x):embedded = self.embedding(x)_, (hidden, _) = self.lstm(embedded)output = self.fc(hidden[-1])return output# 构建词汇表
vocab = {'<PAD>': 0, '<UNK>': 1, 'I': 2, 'love': 3, 'this': 4, 'movie': 5, 'is': 6, 'terrible': 7}
vocab_size = len(vocab)# 示例数据
texts = ["I love this movie", "This movie is terrible"]
labels = [1, 0]  # 1表示正面情感,0表示负面情感# 创建数据集和数据加载器
dataset = TextDataset(texts, labels, vocab)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=lambda x: (torch.nn.utils.rnn.pad_sequence([item[0] for item in x], batch_first=True), torch.stack([item[1] for item in x])))# 实例化模型
embedding_dim = 50
hidden_dim = 50
output_dim = 2
model = LSTMClassifier(vocab_size, embedding_dim, hidden_dim, output_dim)# 使用DataParallel包装模型
model = nn.DataParallel(model)# 将模型移动到GPU
model = model.cuda()# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练步骤
model.train()
for epoch in range(10):  # 训练10个epochfor inputs, labels in dataloader:inputs, labels = inputs.cuda(), labels.cuda()optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()print(f"Epoch {epoch+1}, Loss: {loss.item()}")print("训练完成")# 测试模型
model.eval()
test_texts = ["I love this movie", "This movie is terrible"]
test_dataset = TextDataset(test_texts, [1, 0], vocab)
test_dataloader = DataLoader(test_dataset, batch_size=2, shuffle=False, collate_fn=lambda x: (torch.nn.utils.rnn.pad_sequence([item[0] for item in x], batch_first=True), torch.stack([item[1] for item in x])))with torch.no_grad():for inputs, labels in test_dataloader:inputs, labels = inputs.cuda(), labels.cuda()outputs = model(inputs)predictions = torch.argmax(F.softmax(outputs, dim=1), dim=1)print(f"Predictions: {predictions.cpu().numpy()}, Labels: {labels.cpu().numpy()}")

执行命令:

  • export CUDA_VISIBLE_DEVICES=0,2
  • python train.py

GPU监控

训练前
在这里插入图片描述
训练中
在这里插入图片描述
Epoch 1, Loss: 0.7198400497436523
Epoch 2, Loss: 0.6889444589614868
Epoch 3, Loss: 0.6591541767120361
Epoch 4, Loss: 0.630306601524353
Epoch 5, Loss: 0.6022476553916931
Epoch 6, Loss: 0.5748419761657715
Epoch 7, Loss: 0.5479871034622192
Epoch 8, Loss: 0.5216072201728821
Epoch 9, Loss: 0.4956483840942383
Epoch 10, Loss: 0.47007784247398376
训练完成
Predictions: [1 0], Labels: [1 0]

结论

export CUDA_VISIBLE_DEVICES=0,2与nn.DataParallel(model)结合的方法是正确的

为什么需要指定 CUDA_VISIBLE_DEVICES

  • 在多GPU系统中,默认情况下,PyTorch 会尝试使用所有可用的GPU进行训练。
  • 通过设置 CUDA_VISIBLE_DEVICES 环境变量,用于控制哪些GPU对当前进程可见,PyTorch 只会使用这些可见的GPU进行训练。
  • 通过设置环境变量,你可以在不修改代码的情况下控制使用的GPU。这使得代码更加简洁和通用,不需要在代码中硬编码GPU的选择逻辑。
    总的来说:通过设置 CUDA_VISIBLE_DEVICES 环境变量,你可以灵活地控制哪些GPU对当前进程可见,从而避免资源冲突、简化代码并更好地管理多GPU资源。这是使用 torch.nn.DataParallel 进行多GPU训练时的一种常见做法。

nn.DataParallel原理是什么

nn.DataParallel 是 PyTorch 中用于多 GPU 并行计算的一个模块。它的主要原理是将输入数据分割成多个子集,并将这些子集分配到不同的 GPU 上进行并行计算。具体来说,nn.DataParallel 的工作流程如下:

  • 模型复制:首先,nn.DataParallel 会将模型复制到每个 GPU 上。这意味着每个 GPU 都会有一份完整的模型副本。
  • 数据分割:输入数据会被分割成多个子集,每个子集会被分配到一个 GPU 上。通常,这个分割是按批次(batch)维度进行的。
  • 并行计算:每个 GPU 使用其本地的模型副本对分配到的子集进行前向传播和后向传播计算。
  • 梯度汇总:在所有 GPU 上完成计算后,nn.DataParallel 会将每个 GPU 计算得到的梯度汇总到主 GPU 上(通常是 GPU 0)。
  • 参数更新:主 GPU 汇总梯度后,使用这些梯度更新模型参数。更新后的参数会同步到所有 GPU 上的模型副本。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/377282.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

jmeter分布式(四)

一、gui jmeter的gui主要用来调试脚本 1、先gui创建脚本 先做一个脚本 演示&#xff1a;如何做混合场景的脚本&#xff1f; 用211的业务比例 ①启动数据库服务 数据库服务&#xff1a;包括mysql、redis mysql端口默认3306 netstat -lntp | grep 3306处于监听状态&#xf…

【C++】—— 初识C++

【C】—— 初识C 一、什么是 C二、C 的发展历史三、C 版本更新四、C 的重要性五、C 在工作领域中的运用六、C 书籍推荐&#xff1a; 一、什么是 C C语言 是结构化和模块化的语言&#xff0c;适合处理较小规模的程序。对于复杂的问题&#xff0c;规模较大的程序&#xff0c;需要…

使用嵌入式知识打造智能手环:nRF52蓝牙开发实战(C++/BLE/传感器)

项目概述 现代人越来越注重健康管理&#xff0c;智能穿戴设备应运而生。本项目旨在利用低功耗蓝牙芯片nRF52832&#xff0c;结合加速度计、心率传感器、陀螺仪等传感器&#xff0c;开发一款功能完善、性能稳定的智能运动手环。该手环能够实时采集用户的运动数据和生理指标&…

vue3 - vue项目自动检测更新

vue3 GitHub Demo 地址 vue3在线预览 vue2 GitHub Demo 地址 vue2 在线预览 web项目当页面检测到需要更新&#xff0c;然后弹框提示是否更新&#xff08;刷新页面&#xff09;这种可以通过纯前端实现也可以通过接口实现 接口实现&#xff1a;通过调用接口轮询和本地的版本号比…

护网HW面试——redis利用方式即复现

参考&#xff1a;https://xz.aliyun.com/t/13071 面试中经常会问到ssrf的打法&#xff0c;讲到ssrf那么就会讲到配合打内网的redis&#xff0c;本篇就介绍redis的打法。 未授权 原理&#xff1a; Redis默认情况下&#xff0c;会绑定在0.0.0.0:6379&#xff0c;如果没有采用相关…

FPGA设计之跨时钟域(CDC)设计篇(1)----亚稳态到底是什么?

1、什么是亚稳态? 在数字电路中,如果数据传输时不满足触发器FF的建立时间要求Tsu和保持时间要求Th,就可能产生亚稳态(Metastability),此时触发器的输出端(Q端)在有效时钟沿之后比较长的一段时间都会处于不确定的状态(在0和1之间振荡),而不是等于数据输入端(D端)的…

强制升级最新系统,微软全面淘汰Win10和部分11用户

说出来可能不信&#xff0c;距离 Windows 11 正式发布已过去整整三年时间&#xff0c;按理说现在怎么也得人均 Win 11 水平了吧&#xff1f; 然而事实却是&#xff0c;三年时间过去 Win 11 占有率仅仅突破到 29%&#xff0c;也就跳起来摸 Win 10 屁股的程度。 2024 年 6 月 Wi…

功率继电器【HF46F】

目的&#xff1a;通过单片机控制继电器动作。 原理图如下&#xff0c;原理图中使用的继电器为HF46F5H&#xff0c; 上述原理图的电路原理&#xff1a; 在这个电路图中&#xff0c;电源开关相关的部分包括一个电源开关、一个三极管Q1、一个二极管D2和一个继电器K1。当电源开关…

阿里云ECS服务器安装jdk并运行jar包,访问成功详解

安装 OpenJDK 8 使用 yum 包管理器安装 OpenJDK 8 sudo yum install -y java-1.8.0-openjdk-devel 验证安装 安装完成后&#xff0c;验证 JDK 是否安装成功&#xff1a; java -version设置 JAVA_HOME 环境变量&#xff1a; 为了确保系统中的其他应用程序可以找到 JDK&…

开源必看!50 多个本地运行 LLM 的开源选项

在我之前的文章中&#xff0c;我讨论了使用本地托管的开源权重 LLM 的好处&#xff0c;例如数据隐私和成本节约。通过主要使用免费模型并偶尔切换到 GPT-4&#xff0c;我的月度开支从 20 美元降至 0.50 美元。设置端口转发到本地 LLM 服务器是移动访问的免费解决方案。 有许多…

WSL2 的安装与运行 Linux 系统

前言 适用于 Linux 的 Windows 子系统 (WSL) 是 Windows 的一项功能&#xff0c;允许开发人员在 Windows 系统上直接安装并使用 Linux 发行版。不用进行任何修改&#xff0c;也无需承担传统虚拟机或双启动设置的开销。 可以将 WSL 看作也是一个虚拟机&#xff0c;但是它更为便…

let/const/var的区别及理解

在JavaScript中&#xff0c;let、const 和 var 是用来声明变量的关键字&#xff0c;但它们之间在作用域、变量提升、重复声明等方面存在区别&#xff0c;详细情况如下: 1. let、const、var 的区别 (1) 块级作用域 let 和 const&#xff1a;具有块级作用域&#xff0c;由 {} 包…

记录些Redis题集(2)

Redis 的多路IO复用 多路I/O复用是一种同时监听多个文件描述符&#xff08;如Socket&#xff09;的状态变化&#xff0c;并能在某个文件描述符就绪时执行相应操作的技术。在Redis中&#xff0c;多路I/O复用技术主要用于处理客户端的连接请求和读写操作&#xff0c;以实现高并发…

Redis 中String类型操作命令(命令演示,时间复杂度,返回值,注意事项)

String 类型 文章目录 String 类型set 命令get 命令mset 命令mget 命令get 和 mget 的区别incr 命令incrby 命令decr 命令decrby 命令incrbyfloat 命令append 命令getrange 命令setrange 命令 字符串类型是 Redis 中最基础的数据类型&#xff0c;在讲解命令之前&#xff0c;我们…

STM32基础篇:EXTI × 事件 × EXTI标准库

EXTI EXTI简介 EXTI&#xff1a;译作外部中断/事件控制器&#xff0c;STM32的众多片上外设之一&#xff0c;能够检测外部输入信号的边沿变化并由此产生中断。 例如&#xff0c;在检测按键时&#xff0c;按键按下时会使电平产生翻转&#xff0c;因此可以使用EXTI来读取按下时…

用AirScript脚本给女/男朋友发送每日早安邮件(极简版本)

先看效果 工具 金山文档/WPS提供了每日定时的AirScript脚本服务&#xff0c;非常方便&#xff5e; 话不多说&#xff0c;我们以金山文档为例&#xff0c;只有简单的五个步骤&#xff0c;非常容易&#xff5e; 教程开始 步骤1 我们打开金山文档新建一个智能表格 步骤2 按下图…

基于Python thinker GUI界面的股票评论数据及投资者情绪分析设计与实现

1.绪论 1.1背景介绍 Python 的 Tkinter 库提供了创建用户界面的工具&#xff0c;可以用来构建股票评论数据及投资者情绪分析的图形用户界面&#xff08;GUI&#xff09;。通过该界面&#xff0c;用户可以输入股票评论数据&#xff0c;然后通过情感分析等技术对评论进行情绪分析…

【Linux网络】IP协议{初识/报头/分片/网段划分/子网掩码/私网公网IP/认识网络世界/路由表}

文章目录 1.入门了解2.认识报头3.认识网段4.路由跳转相关指令路由 该文诸多理解参考文章&#xff1a;好文&#xff01; 1.入门了解 用户需求&#xff1a;将我的数据可靠的跨网络从A主机送到B主机 传输层TCP&#xff1a;由各种方法&#xff08;流量控制/超时重传/滑动窗口/拥塞…

【JavaEE】网络编程——TCP

&#x1f921;&#x1f921;&#x1f921;个人主页&#x1f921;&#x1f921;&#x1f921; &#x1f921;&#x1f921;&#x1f921;JavaEE专栏&#x1f921;&#x1f921;&#x1f921; 文章目录 前言1.网络编程套接字1.1流式套接字(TCP)1.1.1特点1.1.2编码1.1.2.1ServerSo…

微信小游戏 彩色试管 倒水游戏 逻辑 (二)

最近开始研究微信小游戏&#xff0c;有兴趣的 可以关注一下 公众号&#xff0c; 记录一些心路历程和源代码。 定义一个 Water class 1. **定义接口和枚举**&#xff1a; - WaterInfo 接口定义了水的颜色、高度等信息。 - PourAction 枚举定义了水的倒动状态&#xff0c;…