PyTorch 中 reciprocal(取倒数)函数的深入解析:分析底层实现CPP代码

PyTorch 中 reciprocal 函数的深入解析

reciprocal: 美 [rɪˈsɪprəkl] [数]倒数; 注意发音

引言

reciprocal 是 PyTorch 和底层 C++ 实现中广泛使用的数学函数,它计算输入的倒数(reciprocal)。倒数在数值计算、反向传播和优化过程中经常使用,尤其是在浮点数缩放和归一化的场景中。本文将从 PyTorch 的 Python 接口出发,逐步深入分析其底层 C++ 实现,帮助读者全面理解 reciprocal 的高效性和适用场景。


1. reciprocal 的基本功能

在 PyTorch 中,reciprocal 用于计算输入张量的倒数。基本用法如下:

import torch
x = torch.tensor([2.0, 4.0, 8.0])
reciprocal_x = x.reciprocal()
print(reciprocal_x)

输出:

tensor([0.5000, 0.2500, 0.1250])

该函数对输入张量逐元素操作,返回每个元素的倒数。

1.1 注意事项

  • 浮点精度问题:由于浮点数表示有限精度,计算结果可能存在细微误差。
  • 零除问题:输入包含零时会产生无穷值(inf)或 NaN,但不会报错。
x = torch.tensor([0.0, 1.0, 2.0])
reciprocal_x = x.reciprocal()
print(reciprocal_x)

输出:

tensor([   inf, 1.0000, 0.5000])

2. 底层 C++ 实现分析

PyTorch 的 reciprocal 函数在底层通过 C++ 实现,针对不同的数据类型和平台进行了优化。以下是关键代码片段:

2.1 标量和向量操作

底层定义的通用函数:

Vectorized<T> reciprocal() const {return map([](T x) { return (T)(1) / x; });
}

这里利用 map 函数实现逐元素操作,将每个元素的倒数映射到新数组。

2.2 特定类型优化

1. 单精度浮点数 (float)
Vectorized<float> reciprocal() const {return Vectorized<float>(vdivq_f32(vdupq_n_f32(1.0f), values));
}

解释

  • vdupq_n_f32(1.0f):将常数 1.0f 广播到所有向量元素。
  • vdivq_f32:利用 NEON 指令集(ARM 架构)实现向量化除法操作。
  • 优势:避免逐元素循环,提高 SIMD(单指令多数据)并行处理速度。
2. 双精度浮点数 (double)
Vectorized<double> reciprocal() const {return svdivr_f64_x(ptrue, values, ONE_F64);
}

解释

  • 使用 ARM SVE(Scalable Vector Extension)指令优化双精度操作。
  • svdivr_f64_x:高效并行除法操作。
  • 优势:适合高性能计算,特别是在多核 CPU 或 GPU 环境下。
3. 复数类型 (Complex)

复数倒数的计算逻辑:

Vectorized<ComplexDbl> reciprocal() const {auto c_d = *this ^ vd_isign_mask; // 取共轭auto abs = abs_2_();return c_d.elwise_div(abs);
}

解释

  • 共轭计算:复数倒数公式依赖于共轭复数。
  • 平方和归一化:计算分母的平方和避免直接除法误差。
  • 逐元素除法:高效实现复数除法操作。

3. PyTorch AMP (自动混合精度) 中的应用

在 PyTorch 中,reciprocal 经常与自动混合精度训练(AMP)结合使用。例如:

scaler = torch.cuda.amp.GradScaler()
inv_scale = scaler.get_scale().double().reciprocal().float()

3.1 动机

  • 防止梯度溢出:在反向传播中,缩放梯度以保持数值稳定性。
  • 高精度计算:避免 FP32 精度不够的问题,通过 FP64 进行关键计算。

3.2 示例代码

from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)for inputs, labels in dataloader:with autocast():outputs = model(inputs)loss = loss_fn(outputs, labels)scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()

在更新过程中,会计算倒数缩放因子,确保数值计算安全。


4. 性能测试与比较

测试环境:

  • GPU: NVIDIA A100
  • PyTorch 版本: 2.0.1
  • 数据集: 随机生成 1,000,000 个浮点数
import torch
torch.manual_seed(0)x = torch.rand(1000000, device='cuda')# 方法1: 原生逐元素倒数
%timeit 1 / x# 方法2: PyTorch reciprocal
%timeit x.reciprocal()

结果示例

1 / x:  3.25 ms ± 0.02 ms per loop
x.reciprocal():  1.04 ms ± 0.01 ms per loop

分析

  • reciprocal 函数利用底层 SIMD 优化,比逐元素除法快约 3倍。这里笔者没测算过,这是GPT4o给出的数据。真实性待核查。
  • 支持 CUDA 加速,可直接在 GPU 上并行计算。

5. 总结

本文详细解析了 PyTorch 中 reciprocal 函数的基本用法、底层 C++ 实现以及其在 AMP 训练中的应用。

关键要点

  1. reciprocal 是计算倒数的高效函数,适用于数值计算和深度学习。
  2. 底层实现利用 SIMD 和 SVE 指令集,针对不同数据类型优化。
  3. 在 AMP 环境中,通过 FP64 确保缩放精度,提升数值稳定性。
  4. 性能测试显示 reciprocal 的速度远快于传统逐元素除法。

通过本文的分析,希望读者能够更深入理解 PyTorch 底层实现和优化策略,并灵活运用 reciprocal 处理复杂计算任务。

后记

2025年1月2日20点19分于上海, 在GPT4o大模型辅助下完成。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/500761.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

天猫推荐数据集实践

参考自 https://github.com/xufengtt/recom_teach_code&#xff0c;学习记录。 环境配置&#xff08;maxcomputedataworks&#xff09; 下载天猫推荐数据集&#xff1b;开启 aliyun 的 maxcompute&#xff0c;dataworks&#xff0c;pai&#xff1b;使用 odpscmd 上传本地数据…

人脑处理信息的速度与效率:超越计算机的直观判断能力

人脑处理信息的速度与效率&#xff1a;超越计算机的直观判断能力 关键词&#xff1a; #人脑信息处理 Human Brain Information Processing #并行处理 Parallel Processing #视觉信息分析 Visual Information Analysis #决策速度 Decision Speed #计算机与人脑比较 Computer v…

checked 溢出问题

{try{int i int.MaxValue;int j;checked{j i 1;}}catch (OverflowException er){Console.WriteLine($"加Checked——>{er.Message}");}}{try{int i int.MaxValue;int j;j i 1;}catch (OverflowException er){Console.WriteLine($"没有加Checked——&g…

LabVIEW 使用 Resample Waveforms VI 实现降采样

在数据采集与信号处理过程中&#xff0c;降采样是一种重要的技术&#xff0c;用于在减少数据点的同时保留信号的关键特性&#xff0c;从而降低存储和计算需求。本文通过 LabVIEW 的 Resample Waveforms (continuous).vi 示例&#xff0c;详细介绍如何使用该功能实现波形数据的降…

数字化供应链创新解决方案在零售行业的应用研究——以开源AI智能名片S2B2C商城小程序为例

摘要&#xff1a; 在数字化转型的浪潮中&#xff0c;零售行业正经历着前所未有的变革。特别是在供应链管理方面&#xff0c;线上线下融合、数据孤岛、消费者需求多样化等问题日益凸显&#xff0c;对零售企业的运营效率与市场竞争力构成了严峻挑战。本文深入探讨了零售行业供应…

《计算机网络》(B)复习

目录 一、问答题测试 1.论述具有五层协议的网络体系结构的要点&#xff0c;包括各层的主要功能。 2.物理层的接口有哪几个方面的特性&#xff1f;各包含些什么内容&#xff1f; 3.小明想要访问淘宝&#xff0c;当他打开浏览器输入www.taobao.com浏览淘宝的 过程是什么&#…

用Tkinter制作一个用于合并PDF文件的小程序

需要安装PyPDF2库&#xff0c;具体原代码如下&#xff1a; # -*- coding: utf-8 -*- """ Created on Sun Dec 29 14:44:20 2024author: YBK """import PyPDF2 import os import tkinter as tk import windndpdf_files [] def dragged_files(f…

“大数据+职业本科”:VR虚拟仿真实训室的发展前景

在新时代背景下&#xff0c;随着科技的飞速进步和产业结构的不断升级&#xff0c;职业教育正迎来前所未有的变革。“大数据职业本科”的新型教育模式&#xff0c;结合VR&#xff08;虚拟现实&#xff09;技术的广泛应用&#xff0c;为实训教学开辟了崭新的道路&#xff0c;尤其…

【异常解决】生产环境 net :: ERR_INCOMPLETE_CHUNKED_ENCODING的问题修复

博主介绍&#xff1a;✌全网粉丝22W&#xff0c;CSDN博客专家、Java领域优质创作者&#xff0c;掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域✌ 技术范围&#xff1a;SpringBoot、SpringCloud、Vue、SSM、HTML、Nodejs、Python、MySQL、PostgreSQL、大数据、物…

矩阵运算提速——玩转opencv::Mat

介绍:用Eigen或opencv::Mat进行矩阵的运算&#xff0c;比用cpp的vector或vector进行矩阵运算要快吗? 使用 Eigen 或 OpenCV 的 cv::Mat 进行矩阵运算通常比使用 std::vector<int> 或 std::vector<double> 更快。这主要有以下几个原因&#xff1a; 优化的底层实现…

mac m2 安装 docker

文章目录 安装1.下载安装包2.在downloads中打开3.在启动台打开打开终端验证 修改国内镜像地址小结 安装 1.下载安装包 到官网下载适配的安装包&#xff1a;https://www.docker.com/products/docker-desktop/ 2.在downloads中打开 拖过去 3.在启动台打开 选择推荐设置 …

redis的集群模式与ELK基础

一、redis的集群模式 1.主从复制 &#xff08;1&#xff09;概述 主从模式&#xff1a;这是redis高可用的基础&#xff0c;哨兵和集群都是建立在此基础之上。 主从模式和数据库的主从模式是一样的&#xff0c;主负责写入&#xff0c;然后把写入的数据同步到从服务器&#xff…

建立一个Macos载入image的实例含界面

前言 为了方便ios程序的开发&#xff0c;有时候需要先用的Macos平台进行一些功能性的程序开发。 作为对比和参考。 1、创建一个MacOS的App 2、主界面控件的增加 添加的控件方法与ios相同&#xff0c;也是再用commandshiftL&#xff08;CtrlShiftL&#xff09;,就会弹出控件…

《机器学习》从入门到实战——逻辑回归

目录 一、简介 二、逻辑回归的原理 1、线性回归部分 2、逻辑函数&#xff08;Sigmoid函数&#xff09; 3、分类决策 4、转换为概率的形式使用似然函数求解 5、对数似然函数 ​编辑 6、转换为梯度下降任务 三、逻辑回归拓展知识 1、数据标准化 &#xff08;1&#xf…

实践:事件循环

实践&#xff1a;事件循环 代码示例 console.log(1); setTimeout(() > console.log(2), 0); Promise.resolve(3).then(res > console.log(res)); console.log(4);上述的代码的输出结果是什么 1和4肯定优先输出&#xff0c;因为他们会立即方式堆栈的执行上下文中执行&am…

【机器学习】工业 4.0 下机器学习如何驱动智能制造升级

我的个人主页 我的领域&#xff1a;人工智能篇&#xff0c;希望能帮助到大家&#xff01;&#xff01;&#xff01;&#x1f44d;点赞 收藏❤ 随着科技的飞速发展&#xff0c;工业 4.0 浪潮正席卷全球制造业&#xff0c;而机器学习作为这一变革中的关键技术&#xff0c;正以前…

自从学会Git,感觉打开了一扇新大门

“同事让我用 Git 提交代码&#xff0c;我居然直接把项目文件压缩发过去了……”相信很多初学者都经历过类似的窘境。而当你真正掌握 Git 时&#xff0c;才会发现它就像一本魔法书&#xff0c;轻松解决代码管理的种种难题。 为什么 Git 能成为程序员的标配工具&#xff1f;它究…

Mono里运行C#脚本21—mono_image_init_name_cache

前面分析了怎么样加载mscorlib.dll文件,然后把文件数据读取到内存。 接着下来,就会遇到加载整个C#的类型系统,比如System. Object,大体类型如下图所示: 在对CIL编译之前,需要把这些类型全部加载到内存里,以便快捷地访问它们。 mono_image_init_name_cache函数就是完成…

【Triton-ONNX】如何使用 ONNX 模型服务与 Triton 通信执行推理任务上-Triton快速开始

模型部署系列文章 前置-docker 理解:【 0 基础 Docker 极速入门】镜像、容器、常用命令总结前置-http/gRPC 的理解: 【HTTP和gRPC的区别】协议类型/传输效率 /性能等对比【保姆级教程附代码】Pytorch (.pth) 到 TensorRT (.plan) 模型转化全流程【保姆级教程附代码(二)】Pytor…

win32汇编环境,对话框中显示bmp图像文件

;运行效果 ;win32汇编环境&#xff0c;对话框中显示bmp图像文件 ;显示的是一张尺寸267*400的bmp位图,及一张缩小为原来三分之一的位图 ;将代码复制进radasm软件里&#xff0c;直接编译就可以运行了 ;下面为asm文件 ;>>>>>>>>>>>>>>&…