PyTorch 添加 C++ 拓展

参考内容:pytorch添加C++拓展简单实战编写及基本功能测试

文章目录

  • 第一步:编写 C++ 模块
    • test.h
    • test.cpp
  • 第二步:编写 setup.py
  • 第三步:安装 C++ 模块
  • 第四步:验证安装
  • 第五步:C++ 模块使用
    • test_cpp1.py
    • test_cpp2.py
  • 运行结果

编译安装前的文件目录:

这里的 csrc 应该不是指 pytorch 项目中的 /torch/csrc

csrc
├─ cpu
│    ├─ test.cpp
│    └─ test.h
└─ setup.py

第一步:编写 C++ 模块

test.h

#include <torch/extension.h>
#include <vector>// 前向传播
torch::Tensor Test_forward_cpu(const torch::Tensor& inputA, const torch::Tensor& inputB);// 反向传播
std::vector<torch::Tensor> Test_backward_cpu(const torch::Tensor& gradOutput);

test.cpp

#include "test.h"// 前向传播
torch::Tensor Test_forward_cpu(const torch::Tensor& x, const torch::Tensor& y){AT_ASSERTM(x.sizes() == y.sizes(), "x must be the same size as y");torch::Tensor z = torch::zeros(x.sizes());z = 2 * x + y;return z;
}// 反向传播
std::vector<torch::Tensor> Test_backward_cpu(const torch::Tensor& gradOutput){torch::Tensor gradOutputX = 2 * gradOutput * torch::ones(gradOutput.sizes());torch::Tensor gradOutputY = gradOutput * torch::ones(gradOutput.sizes());return {gradOutputX, gradOutputY};
}// pybind11 绑定
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){m.def("forward", &Test_forward_cpu, "TEST forward");m.def("backward", &Test_backward_cpu, "TEST backward");
}

第二步:编写 setup.py

from setuptools import setup
import os
import glob
from torch.utils.cpp_extension import BuildExtension, CppExtension# 头文件目录
include_dirs = os.path.dirname(os.path.abspath(__file__))
# 源代码目录
source_cpu = glob.glob(os.path.join(include_dirs, 'cpu', '*.cpp'))setup(name='test_cpp', # 模块名称,需要在 python 中调用version="0.1",ext_modules=[CppExtension('test_cpp', sources=source_cpu, include_dirs=[include_dirs]),],cmdclass={'build_ext': BuildExtension}
)

第三步:安装 C++ 模块

在 csrc 文件夹下运行命令

python setup.py install

第一次尝试的报错信息:

/home/zjma/.conda/envs/debugtest/lib/python3.8/site-packages/setuptools/_distutils/cmd.py:66: SetuptoolsDeprecationWarning: setup.py install is deprecated.
!!********************************************************************************Please avoid running ``setup.py`` directly.Instead, use pypa/build, pypa/installer or otherstandards-based tools.See https://blog.ganssle.io/articles/2021/10/setup-py-deprecated.html for details.********************************************************************************!!self.initialize_options()
/home/zjma/.conda/envs/debugtest/lib/python3.8/site-packages/setuptools/_distutils/cmd.py:66: EasyInstallDeprecationWarning: easy_install command is deprecated.
!!********************************************************************************Please avoid running ``setup.py`` and ``easy_install``.Instead, use pypa/build, pypa/installer or otherstandards-based tools.See https://github.com/pypa/setuptools/issues/917 for details.********************************************************************************!!self.initialize_options()

参考 SetuptoolsDeprecationWarning: setup.py install is deprecated. Use build and pip 后得知是 setuptools 版本太高,于是降低 setuptools 版本,pip install setuptools==58.2.0

第二次尝试的运行结果:

running install
running bdist_egg
running egg_info
writing test_cpp.egg-info/PKG-INFO
writing dependency_links to test_cpp.egg-info/dependency_links.txt
writing top-level names to test_cpp.egg-info/top_level.txt
reading manifest file 'test_cpp.egg-info/SOURCES.txt'
writing manifest file 'test_cpp.egg-info/SOURCES.txt'
installing library code to build/bdist.linux-x86_64/egg
running install_lib
running build_ext
building 'test_cpp' extension
creating /home/zjma/pytorch_v1.13.1/csrc/build/temp.linux-x86_64-3.8
creating /home/zjma/pytorch_v1.13.1/csrc/build/temp.linux-x86_64-3.8/home
creating /home/zjma/pytorch_v1.13.1/csrc/build/temp.linux-x86_64-3.8/home/zjma
creating /home/zjma/pytorch_v1.13.1/csrc/build/temp.linux-x86_64-3.8/home/zjma/pytorch_v1.13.1
creating /home/zjma/pytorch_v1.13.1/csrc/build/temp.linux-x86_64-3.8/home/zjma/pytorch_v1.13.1/csrc
creating /home/zjma/pytorch_v1.13.1/csrc/build/temp.linux-x86_64-3.8/home/zjma/pytorch_v1.13.1/csrc/cpu
Emitting ninja build file /home/zjma/pytorch_v1.13.1/csrc/build/temp.linux-x86_64-3.8/build.ninja...
Compiling objects...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
[1/1] c++ -MMD -MF /home/zjma/pytorch_v1.13.1/csrc/build/temp.linux-x86_64-3.8/home/zjma/pytorch_v1.13.1/csrc/cpu/test.o.d -pthread -B /home/zjma/.conda/envs/debugtest/compiler_compat -Wl,--sysroot=/ -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -fPIC -I/home/zjma/pytorch_v1.13.1/csrc -I/home/zjma/pytorch_v1.13.1/torch/include -I/home/zjma/pytorch_v1.13.1/torch/include/torch/csrc/api/include -I/home/zjma/pytorch_v1.13.1/torch/include/TH -I/home/zjma/pytorch_v1.13.1/torch/include/THC -I/home/zjma/.conda/envs/debugtest/include/python3.8 -c -c /home/zjma/pytorch_v1.13.1/csrc/cpu/test.cpp -o /home/zjma/pytorch_v1.13.1/csrc/build/temp.linux-x86_64-3.8/home/zjma/pytorch_v1.13.1/csrc/cpu/test.o -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1016"' -DTORCH_EXTENSION_NAME=test_cpp -D_GLIBCXX_USE_CXX11_ABI=1 -std=c++14
cc1plus: warning: command-line option ‘-Wstrict-prototypes’ is valid for C/ObjC but not for C++
creating build/lib.linux-x86_64-3.8
g++ -pthread -shared -B /home/zjma/.conda/envs/debugtest/compiler_compat -L/home/zjma/.conda/envs/debugtest/lib -Wl,-rpath=/home/zjma/.conda/envs/debugtest/lib -Wl,--no-as-needed -Wl,--sysroot=/ /home/zjma/pytorch_v1.13.1/csrc/build/temp.linux-x86_64-3.8/home/zjma/pytorch_v1.13.1/csrc/cpu/test.o -L/home/zjma/pytorch_v1.13.1/torch/lib -lc10 -ltorch -ltorch_cpu -ltorch_python -o build/lib.linux-x86_64-3.8/test_cpp.cpython-38-x86_64-linux-gnu.so
creating build/bdist.linux-x86_64/egg
copying build/lib.linux-x86_64-3.8/test_cpp.cpython-38-x86_64-linux-gnu.so -> build/bdist.linux-x86_64/egg
creating stub loader for test_cpp.cpython-38-x86_64-linux-gnu.so
byte-compiling build/bdist.linux-x86_64/egg/test_cpp.py to test_cpp.cpython-38.pyc
creating build/bdist.linux-x86_64/egg/EGG-INFO
copying test_cpp.egg-info/PKG-INFO -> build/bdist.linux-x86_64/egg/EGG-INFO
copying test_cpp.egg-info/SOURCES.txt -> build/bdist.linux-x86_64/egg/EGG-INFO
copying test_cpp.egg-info/dependency_links.txt -> build/bdist.linux-x86_64/egg/EGG-INFO
copying test_cpp.egg-info/top_level.txt -> build/bdist.linux-x86_64/egg/EGG-INFO
writing build/bdist.linux-x86_64/egg/EGG-INFO/native_libs.txt
zip_safe flag not set; analyzing archive contents...
__pycache__.test_cpp.cpython-38: module references __file__
creating 'dist/test_cpp-0.1-py3.8-linux-x86_64.egg' and adding 'build/bdist.linux-x86_64/egg' to it
removing 'build/bdist.linux-x86_64/egg' (and everything under it)
Processing test_cpp-0.1-py3.8-linux-x86_64.egg
removing '/home/zjma/.conda/envs/debugtest/lib/python3.8/site-packages/test_cpp-0.1-py3.8-linux-x86_64.egg' (and everything under it)
creating /home/zjma/.conda/envs/debugtest/lib/python3.8/site-packages/test_cpp-0.1-py3.8-linux-x86_64.egg
Extracting test_cpp-0.1-py3.8-linux-x86_64.egg to /home/zjma/.conda/envs/debugtest/lib/python3.8/site-packages
test-cpp 0.1 is already the active version in easy-install.pthInstalled /home/zjma/.conda/envs/debugtest/lib/python3.8/site-packages/test_cpp-0.1-py3.8-linux-x86_64.egg
Processing dependencies for test-cpp==0.1
Finished processing dependencies for test-cpp==0.1

编译安装后的文件目录:

csrc
├─ build
│    ├─ bdist.linux-x86_64
│    ├─ lib.linux-x86_64-3.8
│    │    └─ test_cpp.cpython-38-x86_64-linux-gnu.so
│    ├─ lib.linux-x86_64-cpython-38
│    │    └─ test_cpp.cpython-38-x86_64-linux-gnu.so
│    ├─ temp.linux-x86_64-3.8
│    │    ├─ .ninja_deps
│    │    ├─ .ninja_log
│    │    ├─ build.ninja
│    │    └─ home
│    └─ temp.linux-x86_64-cpython-38
│           ├─ .ninja_deps
│           ├─ .ninja_log
│           ├─ build.ninja
│           └─ home
├─ cpu
│    ├─ test.cpp
│    └─ test.h
├─ dist
│    └─ test_cpp-0.1-py3.8-linux-x86_64.egg
├─ setup.py
└─ test_cpp.egg-info├─ PKG-INFO├─ SOURCES.txt├─ dependency_links.txt└─ top_level.txt

第四步:验证安装

1、在虚拟环境的路径 /lib/python3.8/site-packages 下看到 test_cpp-0.1-py3.8-linux-x86_64.egg 文件
在这里插入图片描述
2、conda list 查看当前虚拟环境下已经安装的包
在这里插入图片描述3、进入 python 的交互模式,import test_cpp 后报错:

>>> import test_cpp
Traceback (most recent call last):File "<stdin>", line 1, in <module>
ImportError: libc10.so: cannot open shared object file: No such file or directory

参考 通过Python setup.py install的第三方包,import时却无法导入是什么问题呢? - 神经的网络里挣扎的回答 - 知乎,因为编译的 test_cpp 包需要依赖 torch 包,导致无法导入。所以,在 import test_cpp 前要先 import torch

第五步:C++ 模块使用

test_cpp1.py

import torch
import test_cpp
from torch.autograd import Functionclass TestFunction(Function):@staticmethoddef forward(ctx, x, y):return test_cpp.forward(x, y)@staticmethoddef backward(ctx, gradOutput):gradX, gradY = test_cpp.backward(gradOutput)return gradX, gradYclass Test(torch.nn.Module):def __init__(self):super(Test, self).__init__()def forward(self, inputA, inputB):return TestFunction.apply(inputA, inputB)

test_cpp2.py

import torch
from torch.autograd import Variablefrom test_cpp1 import Testx = Variable(torch.Tensor([1,2,3]), requires_grad=True)
y = Variable(torch.Tensor([4,5,6]), requires_grad=True)test = Test()
z = test(x, y)
z.sum().backward()print('x: ', x)
print('y: ', y)
print('z: ', z)
print('x.grad: ', x.grad)
print('y.grad: ', y.grad)

运行结果

/home/zjma/.conda/envs/debugtest/bin/python /home/zjma/PycharmProjects/pythonProject/test_cpp2.py 
x:  tensor([1., 2., 3.], requires_grad=True)
y:  tensor([4., 5., 6.], requires_grad=True)
z:  tensor([ 6.,  9., 12.], grad_fn=<TestFunctionBackward>)
x.grad:  tensor([2., 2., 2.])
y.grad:  tensor([1., 1., 1.])进程已结束,退出代码为 0

运行结果符合预期。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/243636.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

yolov8的目标检测、实例分割、关节点估计的原理解析

1 YOLO时间线 这里简单列下yolo的发展时间线&#xff0c;对每个版本的提出有个时间概念。 2 yolov8 的简介 工程链接&#xff1a;https://github.com/ultralytics/ultralytics 2.1 yolov8的特点 采用了anchor free方式&#xff0c;去除了先验设置可能不佳带来的影响借鉴General…

跑步运动耳机哪个牌子好?2024年国产运动耳机推荐

​无论春夏秋冬&#xff0c;无论室内还是户外&#xff0c;运动都能带给我们无尽的乐趣。而一副好的运动耳机&#xff0c;更能为我们的运动体验增色不少。今天&#xff0c;就让我为大家推荐几款值得一试的运动耳机吧。 1.南卡开放式耳机&#xff08;00压&#xff09; 一句话评价…

VUE项目目录与运行流程(VScode)

各目录对应名称含义 main.js&#xff08;导入App.vue&#xff0c;基于App.vue创建结构渲染index.html&#xff09; //核心作用&#xff1a;导入App.vue&#xff0c;基于App.vue创建结构渲染index.html//1.导入Vue核心包 import Vue from vue//2.导入App.vue根组件 import App f…

如何在 Ubuntu 22.04 上安装 Linux、Apache、MySQL、PHP (LAMP) 堆栈

前些天发现了一个人工智能学习网站&#xff0c;通俗易懂&#xff0c;风趣幽默&#xff0c;最重要的屌图甚多&#xff0c;忍不住分享一下给大家。点击跳转到网站。 如何在 Ubuntu 22.04 上安装 Linux、Apache、MySQL、PHP (LAMP) 堆栈 介绍 “LAMP”堆栈是一组开源软件&#…

基于ssm框架的网上购物系统-计算机毕业设计源码12503

摘 要 近年来&#xff0c;随着移动互联网的快速发展&#xff0c;电子商务越来越受到网民们的欢迎&#xff0c;电子商务对国家经济的发展也起着越来越重要的作用。简单的流程、便捷可靠的支付方式、快捷畅通的物流快递、安全的信息保护都使得电子商务越来越赢得网民们的青睐。现…

控制项目进展

优质博文 IT-BLOG-CN 假如一个项目准备工作做的非常周详&#xff0c;现在要做的就是监督项目的进展情况&#xff0c;理想状况下事情应当进展的很顺利&#xff0c;但实际上我们会发现项目永远不会完全按照经计划执行&#xff0c;我们必须进行项目控制。也就是我们需要不断进行调…

创建SERVLET

创建SERVLET 要创建servlet,需要执行以下任务: 编写servlet。编译并封装servlet。将servlet部署为Java EE应用程序。通过浏览器访问servlet。编写servlet 要编写servlet,需要扩展HttpServlet接口的类。编写servlet是,需要合并读取客户机请求和返回响应的功能。 读取和处…

安全审查常见要求

一、是否有密码复杂度策略、是否有密码有效期 1&#xff09;密码长度至少8位&#xff1b; 2&#xff09;要求用户密码必须包含大小写字母、数字、特殊字符 3&#xff09;避免常见密码 123456&#xff0c;qwerty, password; 4) 强制用户定期修改密码&#xff1b; 5&#x…

分布式深度学习中的数据并行和模型并行

&#x1f380;个人主页&#xff1a; https://zhangxiaoshu.blog.csdn.net &#x1f4e2;欢迎大家&#xff1a;关注&#x1f50d;点赞&#x1f44d;评论&#x1f4dd;收藏⭐️&#xff0c;如有错误敬请指正! &#x1f495;未来很长&#xff0c;值得我们全力奔赴更美好的生活&…

云风网(www.niech.cn)个人网站搭建(二)服务器域名配置

这里直接采用宝塔服务器运维管理面板来进行配置&#xff0c;简单无脑 宝塔 Linux面板8.0.5安装脚本 //Centos安装脚本 yum install -y wget && wget -O install.sh https://download.bt.cn/install/install_6.0.sh && sh install.sh ed8484bec //Ubuntu/Deepi…

使用双异步后,如何保证数据一致性?

目录 一、前情提要二、通过Future获取异步返回值1、FutureTask 是基于 AbstractQueuedSynchronizer实现的2、FutureTask执行流程3、get()方法执行流程 三、FutureTask源码具体分析1、FutureTask源码2、将异步方法的返回值改为Future<Integer>&#xff0c;将返回值放到new…

【Emotion】 自动驾驶最近面试总结与反思

outline 写在前面面试问题回顾和答案展望 写在前面 最近由于公司部门即将撤销&#xff0c;开始了新一轮准备。 发现现在整体行情不太乐观&#xff0c;很看过去的尤其是量产的经验 同时本次面试我coding环节答得不好&#xff0c;&#xff08;其实也是半年前大家问的比较简单…

新版AndroidStudio dependencyResolutionManagement出错

在新版AndroidStudio中想像使用4.2版本或者4.3版本的AndroidStudio来构造项目&#xff1f;那下面这些坑我们就需要来避免了&#xff0c;否则会出各种各样的问题。 一.我们先来看看新旧两个版本的不同。 1.jdk版本的不同 新版默认是jdk17 旧版默认是jdk8 所以在新版AndroidSt…

javaSSMmysql书籍借阅管理系统04770-计算机毕业设计项目选题推荐(附源码)

摘 要 随着科学技术的告诉发展&#xff0c;我们已经步入数字化、网络化的时代。图书馆是学校的文献信息中心&#xff0c;是为全校教学和科学研究服务的学术性机构&#xff0c;是学校信息化的重要基地。图书馆的工作是学校和科学研究工作的重要组成部分&#xff0c;是全校师生学…

《WebKit 技术内幕》学习之十(1): 插件与JavaScript扩展

虽然目前的浏览器的功能很强 &#xff0c;但仍然有其局限性。早期的浏览器能力十分有限&#xff0c;Web前端开发者希望能够通过一定的机制来扩展浏览器的能力。早期的方法就是插件机制&#xff0c;现在流行次啊用混合编程&#xff08;Hybird Programming&#xff09;模式。插件…

决策树的基本构建流程

决策树的基本构建流程 决策树的本质是挖掘有效的分类规则&#xff0c;然后以树的形式呈现。 这里有两个重点&#xff1a; 有效的分类规则&#xff1b;树的形式。 有效的分类规则&#xff1a;叶子节点纯度越高越好&#xff0c;就像我们分红豆和黄豆一样&#xff0c;我们当然…

表单的总数据为什么可以写成一个空对象,不用具体的写表单中绑定的值,vue3

<el-form :model"form" label-width"120px"><el-form-item label"Activity name"><el-input v-model"form.name" /></el-form-item> </el-form> const form ref({})from为空对象 在v-model里写form…

Python 猎户星空Orion-14B,截止到目前为止,各评测指标均名列前茅,综合指标最强;Orion-14B表现强大,LLMs大模型

1.简介 Orion-14B-Base是一个具有140亿参数的多语种大模型&#xff0c;该模型在一个包含2.5万亿token的多样化数据集上进行了训练&#xff0c;涵盖了中文、英语、日语、韩语等多种语言。在多语言环境下的一系列任务中展现出卓越的性能。在主流的公开基准评测中&#xff0c;Orio…

Tensorflow2.0笔记 - tensor的合并和分割

主要记录concat,stack,unstack和split相关操作的作用 import tensorflow as tf import numpy as nptf.__version__#concat对某个维度进行连接 #假设下面的tensor0和tensor1分别表示4个班级35名同学的8门成绩和两个班级35个同学8门成绩 tensor0 tf.ones([4,35,8]) tensor1 tf…

centos安装:node.js、npm及pm2

前言 Node.js发布于2009年5月&#xff0c;由Ryan Dahl开发&#xff0c;是一个基于Chrome V8引擎的JavaScript运行环境&#xff0c;使用了一个事件驱动、非阻塞式I/O模型&#xff0c;让JavaScript 运行在服务端的开发平台&#xff0c;它让JavaScript成为与PHP、Python、Perl、Ru…