Custom C++ and CUDA Extensions - PyTorch

0. Abstract

经历了一波 pybind11 和 CUDA 编程 的学习, 接下来看一看 PyTorch 官方给的 C++/CUDA 扩展的教程. 发现极其简单, 就是直接用 setuptools 导出 PyTorch C++ 版代码的 Python 接口就可以了. 所以, 本博客包含以下内容:

  • LibTorch 初步;
  • C++ Extension 例子;

1. LibTorch 初步

在 PyTorch 的首页安装指引中就可以看到 PyTorch 是支持 C++/Java 的:

下载后解压到一个地方, 如 /opt/libtorch. 然后就可以使用 C++ 编写 PyTorch 程序了. 官方给的有相关例子, 我们选择最经典的 MNIST 手写数字识别项目来看一看:

mnist/
├── CMakeLists.txt
├── README.md
└── mnist.cpp

1.1 CMake 项目

CMakeLists.txt 是构建 cpp 项目的说明文件:

cmake_minimum_required(VERSION 3.5)
project(mnist)
set(CMAKE_CXX_STANDARD 17)find_package(Torch REQUIRED)option(DOWNLOAD_MNIST "Download the MNIST dataset from the internet" ON)
if (DOWNLOAD_MNIST)message(STATUS "Downloading MNIST dataset")execute_process(COMMAND python ${CMAKE_CURRENT_LIST_DIR}/../tools/download_mnist.py -d ${CMAKE_BINARY_DIR}/dataERROR_VARIABLE DOWNLOAD_ERROR)if (DOWNLOAD_ERROR)message(FATAL_ERROR "Error downloading MNIST dataset: ${DOWNLOAD_ERROR}")endif()
endif()add_executable(mnist mnist.cpp)
target_compile_features(mnist PUBLIC cxx_range_for)
target_link_libraries(mnist ${TORCH_LIBRARIES})

为了下载 MNIST 数据集, 这里用到了一个 Python 文件 ../tools/download_mnist.py, 执行 cmake 后, 编译根目录(build)会出现一个 data 数据文件夹.

  • find_package(Torch REQUIRED) 查找 libtorch 时可能需要指定路径:
    find_package(Torch REQUIRED PATHS "path/to/libtorch/")
  • make 时, Ubuntu18.04 下出现错误: undefined reference to symbol ‘pthread_create@@GLIBC_2.2.5’.
    => 经查阅资料, 说: pthread 不是 linux 下的默认的库, 也就是在链接的时候, 无法找到 phread 库中线程函数的入口地址, 于是链接会失败.
    => 解决方案: target_link_libraries(mnist ${TORCH_LIBRARIES} -lpthread -lm)

make 之后, 执行 ./mnist 就能进行训练与测试了:

CUDA available! Training on GPU.
Train Epoch: 1 [59584/60000] Loss: 0.2078
Test set: Average loss: 0.2062 | Accuracy: 0.935
Train Epoch: 2 [59584/60000] Loss: 0.2039
Test set: Average loss: 0.1304 | Accuracy: 0.959
...

1.2 PyTorch C++ API

接下来看 C++ 代码:

struct Net : torch::nn::Module
{Net() : conv1(torch::nn::Conv2dOptions(1, 10, /*kernel_size=*/5)),conv2(torch::nn::Conv2dOptions(10, 20, /*kernel_size=*/5)),fc1(320, 50),fc2(50, 10){register_module("conv1", conv1);register_module("conv2", conv2);register_module("conv2_drop", conv2_drop);register_module("fc1", fc1);register_module("fc2", fc2);}torch::Tensor forward(torch::Tensor &x){x = torch::relu(torch::max_pool2d(conv1->forward(x), 2));x = torch::relu(torch::max_pool2d(conv2_drop->forward(conv2->forward(x)), 2));x = x.view({-1, 320});x = torch::relu(fc1->forward(x));x = torch::dropout(x, /*p=*/0.5, /*training=*/is_training());x = fc2->forward(x);return torch::log_softmax(x, /*dim=*/1);}torch::nn::Conv2d conv1;torch::nn::Conv2d conv2;torch::nn::Dropout2d conv2_drop;torch::nn::Linear fc1;torch::nn::Linear fc2;
};template<typename DataLoader>
void train(size_t epoch,Net &model,torch::Device device,DataLoader &data_loader,torch::optim::Optimizer &optimizer,size_t dataset_size
)
{model.train();size_t batch_idx = 0;for (auto &batch: data_loader){auto data = batch.data.to(device), targets = batch.target.to(device);auto output = model.forward(data);auto loss = torch::nll_loss(output, targets);AT_ASSERT(!std::isnan(loss.template item<float>()));optimizer.zero_grad();loss.backward();optimizer.step();...}
}

可以看到, 代码非常简单, 几乎和 Python 接口一致, 如果把 :: 换成 ., 就更像了. 不一样的是多了些类型限制以及一些语法. 具体的我们不多研究, 终究还是没有 Python 简洁好用. 但简单了解一下 PyTorch C++ API 的文档说明还是有必要的:

所以, 这个 LibTorch 既能用来写 C++ 项目, 也能用来给 PyTorch 写扩展. 不过官方还是推荐使用 Python 接口:

2. C++ Extension 例子

官方文档给的例子比较复杂, 这里举一个简单的例子, 把计算:

y = torch.relu(torch.matmul(x, w.t()) + b)

整合到一个操作里, 也就是使用 LibTorch C++ 编写一个等价的运算, 并导出 Python 接口. 这么做的理由是:

大概意思就是 Python 比较慢, 由 Python 一次次调用操作而频繁启动 CUDA 核会拖慢速度.

其实我觉得只有用 CUDA 编程把序列操作整合起来才能真正减少 CUDA 核的频繁启动, LibTorch 能加速可能就是因为 C++ 更快而已.

直接上代码吧, 整个项目的解构是这样子的:

LinearAct/
├── linearfun.py
├── linearact.cpp
└── setup.py

linearact.cpp 包含了组合操作的 forward 过程和 backward 过程, 前者计算正向的正常计算, 后者计算反向的梯度计算:

#include <torch/extension.h>  // 注意这里头文件和直接写 C++ 项目不一样
#include <vector>std::vector<at::Tensor> forward(torch::Tensor &input, torch::Tensor &weight, torch::Tensor &bias)
{auto relu_input = input.mm(weight.t()) + bias;auto output = torch::relu(relu_input);return {relu_input, output};  // relu_input 会在梯度计算时用到
}std::vector<torch::Tensor>
backward(torch::Tensor &grad_output, torch::Tensor &relu_input, torch::Tensor &input, torch::Tensor &weight)
{   // 求导链式法则auto grad_relu = grad_output.masked_fill(relu_input < 0, 0);auto grad_input = grad_relu.mm(weight);auto grad_weight = grad_relu.t().mm(input);auto grad_bias = grad_relu.sum(0);return {grad_input, grad_weight, grad_bias};
}PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {m.def("forward", &forward, "Custom forward");m.def("backward", &backward, "Custom backward");
}

这涉及到 pybind11 的用法, 详情见《pybind11 学习笔记》, 还涉及到使用 torch.autograd.Function 自定义运算的梯度计算, 详情见《PyTorch 中的 apply [autograd.Function]》. 总之, 现在我们使用 LibTorch 写了组合操作, 并写了其参数的梯度计算. linearfun.py 是利用 torch.autograd.Functionforwardbackward 整合到一起, 组成一个完整的可以进行反向梯度传播的组合运算:

import torch  # 注意, 导入 linearact 前, 应先导入 torch
import linearactclass LinearActFunction(torch.autograd.Function):@staticmethoddef forward(ctx, input, weights, bias):relu_input, output = linearact.forward(input, weights, bias)  # c++ 函数variables = [relu_input, input, weights]ctx.save_for_backward(*variables)return output@staticmethoddef backward(ctx, grad_output):outputs = linearact.backward(grad_output, *ctx.saved_tensors)  # c++ 函数grad_x, grad_w, grad_b = outputsreturn grad_x, grad_w, grad_bmylinear = LinearActFunction.apply

LibTorch C++ 代码由 setuptools 导出 Python 接口:

from setuptools import setup
from torch.utils import cpp_extensionsetup(name='linearact',ext_modules=[cpp_extension.CppExtension('linearact', ['linearact.cpp'])],cmdclass={'build_ext': cpp_extension.BuildExtension}  # 整合了 pybind11 的功能
)

在命令行执行:

python setup.py install

就可以将 linearact 包安装到 Python 系统中, 任务完成. 下面进行验证:

import torch
from linearfun import mylinearx = torch.randn(2, 3, requires_grad=True)
w = torch.randn(2, 3, requires_grad=True)
b = torch.randn(2, requires_grad=True)
# 复制一份一样的参数
x1 = torch.from_numpy(x.detach().numpy())
w1 = torch.from_numpy(w.detach().numpy())
b1 = torch.from_numpy(b.detach().numpy())
x1.requires_grad_(True)
w1.requires_grad_(True)
b1.requires_grad_(True)# %% pytorch
y = torch.relu(torch.matmul(x, w.t()) + b)
y = y.norm(p=2)
print(y)y.backward()
print(x.grad)
print(w.grad)
print(b.grad)# %% custom
print('---------------------------')
y = mylinear(x1, w1, b1)
y = y.norm(p=2)
print(y)y.backward()
print(x1.grad)
print(w1.grad)
print(b1.grad)

执行一次:

tensor(1.2664, grad_fn=<LinalgVectorNormBackward0>)
tensor([[ 0.0851, -1.0418,  0.3958],[ 0.0566, -0.6925,  0.2631]])
tensor([[ 0.0000,  0.0000,  0.0000],[-1.0724,  0.3669, -0.1399]])
tensor([0.0000, 1.3864])
---------------------------
tensor(1.2664, grad_fn=<LinalgVectorNormBackward0>)
tensor([[ 0.0851, -1.0418,  0.3958],[ 0.0566, -0.6925,  0.2631]])
tensor([[ 0.0000,  0.0000,  0.0000],[-1.0724,  0.3669, -0.1399]])
tensor([0.0000, 1.3864])

可以看见两者一模一样. 至于测速什么的不在本博文的考虑范围之内, 只是想了解 PyTorch 如何进行 C++ 扩展.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/439397.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

国庆刷题(day4)

C语言&#xff1a; C&#xff1a;

gdb 调试 linux 应用程序的技巧介绍

使用 gdb 来调试 Linux 应用程序时&#xff0c;可以显著提高开发和调试的效率。gdb&#xff08;GNU 调试器&#xff09;是一款功能强大的调试工具&#xff0c;适用于调试各类 C、C 程序。它允许我们在运行程序时检查其状态&#xff0c;设置断点&#xff0c;跟踪变量值的变化&am…

电脑手机下载小米xiaomi redmi刷机包太慢 解决办法

文章目录 修改前下载速度修改后下载速度修改方法&#xff08;修改host&#xff09; 修改前下载速度 一开始笔者以为是迅雷没开会员的问题&#xff0c;在淘宝上买了一个临时会员后下载速度依然最高才100KB/s 修改后下载速度 修改方法&#xff08;修改host&#xff09; host文…

Python编码规范与常见问题纠正

Python编码规范与常见问题纠正 Python 是一种以简洁和易读性著称的编程语言&#xff0c;因此&#xff0c;遵循良好的编码规范不仅能使代码易于维护&#xff0c;还能提升代码的可读性和可扩展性。编写规范的 Python 代码也是开发者职业素养的一部分&#xff0c;本文将从 Python…

TryHackMe 第6天 | Web Fundamentals (一)

这一部分我们要简要介绍以下 Web Hacking 的基本内容&#xff0c;预计分三次博客。 在访问 Web 应用时&#xff0c;浏览器提供了若干个工具来帮助我们发现一些潜在问题和有用的信息。 比如可以查看网站源代码。查看源代码可以 右键 网页&#xff0c;然后选择 查看网站源代码&…

【复习】CSS中的选择器

文章目录 东西有点多 以实战为主选择器盒子模型 东西有点多 以实战为主 选择器 CSS选择器&#xff08;CSS Selectors&#xff09;是用于在HTML或XML文档中查找和选择元素&#xff0c;以便应用CSS样式的一种方式。 元素选择器&#xff08;Type Selector&#xff09; 选择所有…

探索 aMQTT:Python中的AI驱动MQTT库

文章目录 探索 aMQTT&#xff1a;Python中的AI驱动MQTT库背景介绍aMQTT是什么&#xff1f;如何安装aMQTT&#xff1f;简单库函数使用方法场景应用常见问题及解决方案总结 探索 aMQTT&#xff1a;Python中的AI驱动MQTT库 背景介绍 在物联网和微服务架构的浪潮中&#xff0c;MQ…

CSS3练习--电商web

免责声明&#xff1a;本文仅做分享&#xff01; 目录 小练--小兔鲜儿 目录构建 SEO 三大标签 Favicon 图标 布局网页 版心 快捷导航&#xff08;shortcut&#xff09; 头部&#xff08;header&#xff09; logo 导航 搜索 购物车 底部&#xff08;footer&#xff0…

2024年计算机视觉与艺术研讨会(CVA 2024)

目录 基本信息 大会简介 征稿主题 会议议程 参会方式 基本信息 大会官网&#xff1a;www.icadi.net&#xff08;点击了解参会投稿等信息&#xff09; 大会时间&#xff1a;2024年11月29-12月1日 大会地点&#xff1a;中国-天津 大会简介 2024年计算机视觉与艺术国际学术…

Redis --- 第三讲 --- 通用命令

一、get和set命令 Redis中最核心的两个命令 get 根据key来取value set 把key和value存储进去 redis是按照键值对的方式存储数据的。必须要先进入到redis客户端。 语法 set key value &#xff1a; key和value都是字符串。 对于上述这里的key value 不需要加上引号&#…

数据库概述(1)

课程主页&#xff1a;Guoliang Li Tsinghua 数据库在计算机系统中的位置 首先&#xff0c;数据库是在设计有大量数据存储需求的软件时必不可少可的基础。 最常见的是&#xff1a;我们通过app或者是浏览器来实现一些特定需求——比如转账、订车票。即引出背后的CS和BS两种网…

重学SpringBoot3-集成Redis(三)

更多SpringBoot3内容请关注我的专栏&#xff1a;《SpringBoot3》 期待您的点赞&#x1f44d;收藏⭐评论✍ 重学SpringBoot3-集成Redis&#xff08;三&#xff09; 1. 引入 Redis 依赖2. 配置 RedisCacheManager 及自定义过期策略2.1 示例代码&#xff1a;自定义过期策略 3. 配置…

如何使用ssm实现民族大学创新学分管理系统分析与设计+vue

TOC ssm763民族大学创新学分管理系统分析与设计vue 第1章 绪论 1.1 课题背景 二十一世纪互联网的出现&#xff0c;改变了几千年以来人们的生活&#xff0c;不仅仅是生活物资的丰富&#xff0c;还有精神层次的丰富。在互联网诞生之前&#xff0c;地域位置往往是人们思想上不…

【rCore OS 开源操作系统】Rust 字符串(可变字符串String与字符串切片str)

【rCore OS 开源操作系统】Rust 语法详解: Strings 前言 这次涉及到的题目相对来说比较有深度&#xff0c;涉及到 Rust 新手们容易困惑的点。 这一次在直接开始做题之前&#xff0c;先来学习下字符串相关的知识。 Rust 的字符串 Rust中“字符串”这个概念涉及多种类型&…

【EXCEL数据处理】000017 案例 Match和Index函数。

前言&#xff1a;哈喽&#xff0c;大家好&#xff0c;今天给大家分享一篇文章&#xff01;创作不易&#xff0c;如果能帮助到大家或者给大家一些灵感和启发&#xff0c;欢迎收藏关注哦 &#x1f495; 目录 【EXCEL数据处理】000016 案例 Match和Index函数。使用的软件&#xff…

DenseNet算法:口腔癌识别

本文为为&#x1f517;365天深度学习训练营内部文章 原作者&#xff1a;K同学啊 一 DenseNet算法结构 其基本思路与ResNet一致&#xff0c;但是它建立的是前面所有层和后面层的密集连接&#xff0c;它的另一大特色是通过特征在channel上的连接来实现特征重用。 二 设计理念 三…

成都跃享未来教育咨询有限公司抖音小店:引领教育咨询新风尚

在数字化浪潮席卷全球的今天&#xff0c;教育咨询行业正经历着前所未有的变革。成都跃享未来教育咨询有限公司&#xff0c;作为教育行业的一颗璀璨新星&#xff0c;凭借其前瞻性的教育理念与创新的运营模式&#xff0c;在抖音平台上开设了小店&#xff0c;不仅为广大学子及家长…

学籍管理平台|在线学籍管理平台系统|基于Springboot+VUE的在线学籍管理平台系统设计与实现(源码+数据库+文档)

在线学籍管理平台系统 目录 基于SpringbootVUE的在线学籍管理平台系统设计与实现 一、前言 二、系统功能设计 三、系统实现 四、数据库设计 1、实体ER图 五、核心代码 六、论文参考 七、最新计算机毕设选题推荐 八、源码获取&#xff1a; 博主介绍&#xff1a;✌️大…

C++多态、虚函数以及抽象类

目录 1.多态的概念 2.多态的定义及实现 2.1多态的构成条件 2.1.1实现多态还有两个必要条件 2.1.2虚函数 2.1.3虚函数的重写/覆盖 2.1.4多态场景的题目 2.1.5虚函数重写的一些其他问题 2.1.5.1协变(了解) 2.1.5.2析构函数的重写 2.1.6override和final关键字 2.…

【Springer上传手稿记录】《Signal, Image and Video Processing》

Springer上传手稿记录 以signal&#xff0c;image and video proecessing为例上传手稿或图片时提示上传失败无法编译成pdf错误1&#xff1a;出现Unknown theoremstyle相关错误错误2&#xff1a;command Illegal错误3&#xff1a;command I found no style file 通用问题问题1&a…