深度学习中的并行策略概述:2 Data Parallelism

深度学习中的并行策略概述:2 Data Parallelism
在这里插入图片描述
数据并行(Data Parallelism)的核心在于将模型的数据处理过程并行化。具体来说,面对大规模数据批次时,将其拆分为较小的子批次,并在多个计算设备上同时进行处理。每个设备负责处理一个子批次,实现并行计算。处理完成后,将各个设备上的计算结果汇总,以便对模型进行统一更新。由于其在深度学习中的普遍应用,数据并行成为了一种广泛支持的并行计算策略,并在主流框架中得到了良好的实现。

以下代码展示了如何在PyTorch中使用nn.DataParallel和DistributedDataParallel实现数据并行,以加速模型的训练过程。

使用nn.DataParallel实现数据并行

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader# 假设我们有一个简单的数据集类
class SimpleDataset(Dataset):def __init__(self, data, target):self.data = dataself.target = targetdef __len__(self):return len(self.data)def __getitem__(self, idx):return self.data[idx], self.target[idx]# 假设我们有一个简单的神经网络模型
class SimpleModel(nn.Module):def __init__(self, input_dim):super(SimpleModel, self).__init__()self.fc = nn.Linear(input_dim, 1)def forward(self, x):return torch.sigmoid(self.fc(x))# 假设我们有一些数据
n_sample = 100
n_dim = 10
batch_size = 10
X = torch.randn(n_sample, n_dim)
Y = torch.randint(0, 2, (n_sample,)).float()
dataset = SimpleDataset(X, Y)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)# 初始化模型
device_ids = [0, 1, 2]  # 指定使用的GPU编号
model = SimpleModel(n_dim).to(device_ids[0])
model = nn.DataParallel(model, device_ids=device_ids)# 定义优化器和损失函数
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.BCELoss()# 训练模型
for epoch in range(10):for batch_idx, (inputs, targets) in enumerate(data_loader):inputs, targets = inputs.to('cuda'), targets.to('cuda')outputs = model(inputs)loss = criterion(outputs, targets.unsqueeze(1))optimizer.zero_grad()loss.backward()optimizer.step()print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item()}')

使用DistributedDataParallel实现数据并行

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP# 假设我们有一个简单的数据集类
class SimpleDataset(Dataset):def __init__(self, data, target):self.data = dataself.target = targetdef __len__(self):return len(self.data)def __getitem__(self, idx):return self.data[idx], self.target[idx]# 假设我们有一个简单的神经网络模型
class SimpleModel(nn.Module):def __init__(self, input_dim):super(SimpleModel, self).__init__()self.fc = nn.Linear(input_dim, 1)def forward(self, x):return torch.sigmoid(self.fc(x))# 初始化进程组
def init_process(rank, world_size, backend='nccl'):dist.init_process_group(backend, rank=rank, world_size=world_size)# 训练函数
def train(rank, world_size):init_process(rank, world_size)torch.cuda.set_device(rank)model = SimpleModel(10).to(rank)model = DDP(model, device_ids=[rank])dataset = SimpleDataset(torch.randn(100, 10), torch.randint(0, 2, (100,)).float())sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=world_size, rank=rank)data_loader = DataLoader(dataset, batch_size=10, sampler=sampler)optimizer = optim.SGD(model.parameters(), lr=0.01)criterion = nn.BCELoss()for epoch in range(10):for inputs, targets in data_loader:inputs, targets = inputs.to(rank), targets.to(rank)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, targets.unsqueeze(1))loss.backward()optimizer.step()if __name__ == "__main__":world_size = 4torch.multiprocessing.spawn(train, args=(world_size,), nprocs=world_size, join=True)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/496523.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何快速找到合适的科学问题

前面已经讲过 如何快速判断学术论文质量与相关性 如何描述科学问题?从“术”入手,悟出属于自己的“道” 医学图像分割任务中的典型科学问题 如何快速肝论文? 博士论文的写作架构 这些内容分别阐述了 如何找到重要的相关论文 找到科学问…

如何为运行在 PICO 4 Ultra 设备上的项目设置外部文件读写权限?

PICO 4 Ultra 系列设备使用的安卓操作系统为 Android 14。当项目的 Write Permission 为 Externa (SDCard) 且 Android API Level 大于 32 时,Unity 提供的外部文件读取方式在 PICO 4 Ultra 设备上将失效。此问题提供两种解决方法,按实际情况选取。 解决…

MacOS安装Xcode(非App Store)

文章目录 访问官网资源页面 访问官网资源页面 直接访问官网的历史版本下载资源页面地址:https://developer.apple.com/download/more/完成APP ID的登陆,直接找到需要的软件下载即可 解压后,安装将xcode.app移动到应用程序文件夹。

OpenLinkSaas使用手册-Git工具

在OpenLinkSaas的工具箱里面,最基础的一个就是Git仓库管理。Git仓库功能让git使用更加简单和强大,不仅可以使用常规的commit/pull/push/branch等功能外,还连接了Git仓库供应商的能力。 OpenLinkSass支持使用国内主流的Git仓库供应商的账号登录…

.NET平台用C#通过字节流动态操作Excel文件

在.NET开发中,通过字节流动态操作Excel文件提供了一种高效且灵活的方式处理数据。这种方法允许开发者直接在内存中创建、修改和保存Excel文档,无需依赖直接的文件储存、读取操作,从而提高了程序的性能和安全性。使用流技术处理Excel不仅简化了…

vue之axios基本使用

文章目录 1. axios 网络请求库2. axiosvue 1. axios 网络请求库 <body> <input type"button" value"get请求" class"get"> <input type"button" value"post请求" class"post"> <!-- 官网提供…

STM32开发笔记123:使用FlyMcu下载程序

文章目录 前言一、FlyMcu二、电路图三、使用方法1、配置2、读取器件信息3、擦除芯片4、加载文件下载程序5、启动应用程序前言 本文介绍使用FlyMcu下载程序到STM32微控制器的方法。 一、FlyMcu FlyMcu轻量级,比STM32CubeProgrammer使用更为简便,下载地址:http://www.mcuis…

mysql返回N/A

在写统计图的接口&#xff0c;sql查询一直无数据&#xff0c;给的默认值也没有实现&#xff1a; SELECTifnull( unit.num, 0 ) riskUnitCount,ifnull( EVENT.num, 0 ) riskEventCount,ifnull( measure.num, 0 ) riskMeasureCount FROMtb_companyLEFT JOIN (SELECTrisk.qyid,co…

Linux网络——TCP的运用

系列文章目录 文章目录 系列文章目录一、服务端实现1.1 创建套接字socket1.2 指定网络接口并bind2.3 设置监听状态listen2.4 获取新链接accept2.5 接收数据并处理&#xff08;服务&#xff09;2.6 整体代码 二、客户端实现2.1 创建套接字socket2.2 指定网络接口2.3 发起链接con…

C/C++ 数据结构与算法【哈夫曼树】 哈夫曼树详细解析【日常学习,考研必备】带图+详细代码

哈夫曼树&#xff08;最优二叉树&#xff09; 1&#xff09;基础概念 **路径&#xff1a;**从树中一个结点到另一个结点之间的分支构成这两个结点间的路径。 **结点的路径长度&#xff1a;**两结点间路径上的分支数。 **树的路径长度&#xff1a;**从树根到每一个结点的路径…

Nginx的性能分析与调优简介

Nginx的性能分析与调优简介 一、Nginx的用途二、Nginx负载均衡策略介绍与调优三、其他调优方式简介四、Nginx的性能监控 一、Nginx的用途 ‌Nginx是一种高性能的HTTP和反向代理服务器&#xff0c;最初作为HTTP服务器开发&#xff0c;主要用于服务静态内容如HTML文件、图像、视…

uniapp使用live-pusher实现模拟人脸识别效果

需求&#xff1a; 1、前端实现模拟用户人脸识别&#xff0c;识别成功后抓取视频流或认证的一张静态图给服务端。 2、服务端调用第三方活体认证接口&#xff0c;验证前端传递的人脸是否存在&#xff0c;把认证结果反馈给前端。 3、前端根据服务端返回的状态&#xff0c;显示在…

基于SpringBoot的“房产销售平台”的设计与实现(源码+数据库+文档+PPT)

基于SpringBoot的“房产销售平台”的设计与实现&#xff08;源码数据库文档PPT) 开发语言&#xff1a;Java 数据库&#xff1a;MySQL 技术&#xff1a;SpringBoot 工具&#xff1a;IDEA/Ecilpse、Navicat、Maven 系统展示 系统整体模块图 登录窗口界面 房源信息管理窗口界…

EMC——射频场感应的传导骚扰抗扰度(CS)

术语和定义 AE&#xff08;辅助设备&#xff09; 为受试设备正常运行提供所需信号的设备和检验受试设备性能的设备&#xff1b; 钳注入 是用电缆上的钳合式“电流”注入装置获得的钳注入&#xff1b; 电流钳 由被注入信号的电缆构成的二次绕组实现的电流变换器&#xff1b; 电磁…

探究音频丢字位置和丢字时间对pesq分数的影响

丢字的本质 丢字的本质是在一段音频中一小段数据变为0 丢字对主观感受的影响 1. 丢字位置 丢字的位置对感知效果有很大影响。如果丢字发生在音频信号的静音部分或低能量部分&#xff0c;感知可能不明显&#xff1b;而如果丢字发生在高能量部分或关键音素上&#xff0c;感知…

WordPress网站中如何修复504错误

504网关超时错误是非常常见的一种网站错误。这种错误发生在上游服务器未能在规定时间内完成请求的情况下&#xff0c;对访问者而言&#xff0c;出现504错误无疑会对访问体验大打折扣&#xff0c;从而对网站的转化率和收入造成负面影响。 504错误通常源于服务器端或网站本身的问…

自学记录HarmonyOS Next的HMS AI API 13:语音合成与语音识别

在完成图像处理项目后&#xff0c;我打算研究一下API 13的AI其中的——语音技术。HarmonyOS Next的最新API 13中&#xff0c;HMS AI Text-to-Speech和HMS AI Speech Recognizer提供了语音合成与语音识别的强大能力。 语音技术是现代智能设备的重要组成部分&#xff0c;从语音助…

从百度云网盘下载数据到矩池云网盘或者服务器内

本教程教大家如何快速将百度云网盘数据集或者模型代码文件下载到矩池云网盘或者服务器硬盘上。 本教程使用到了一个开源工具 BaiduPCS-Go&#xff0c;官方地址 &#xff1a; https://github.com/qjfoidnh/BaiduPCS-Go 这个工具可以实现“仿 Linux shell 文件处理命令的百度网…

2024基于大模型的智能运维(附实践资料合集)

基于大模型的智能运维是指利用人工智能技术&#xff0c;特别是大模型技术&#xff0c;来提升IT运维的效率和质量。以下是一些关键点和实践案例&#xff1a; AIOps的发展&#xff1a;AIOps&#xff08;人工智能在IT运维领域的应用&#xff09;通过大数据分析和机器学习技术&…

通过Js动态控制Bootstrap模态框-弹窗效果

目的&#xff1a;实现弹出窗、仅关闭弹窗之后才能操作&#xff08;按ESC可退出&#xff09;。自适应宽度与高度、当文本内容太多时、添加滚动条效果。 效果图 源码 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8">…