【Pytorch】(十三)模型部署: TorchScript

文章目录

  • (十三)模型部署: TorchScript
    • Pytorch动态图的优缺点
    • TorchScript
    • Pytorch模型转换为TorchScript
      • torch.jit.trace
      • torch.jit.script
      • trace和script的区别总结
      • trace 和script 混合使用
      • 保存和加载模型

(十三)模型部署: TorchScript

Pytorch动态图的优缺点

与Tensorflow使用静态计算图不同,PyTorch 使用的是动态计算图:

动态图允许在运行时渐进地构建计算图,使得模型设计更加灵活。开发者可以使用 Python 的控制流结构(如循环、条件语句等)来动态地定义模型的结构,从而更容易实现复杂的模型逻辑。

这种计算方式更直观,更pythonic。开发者可以更容易地理解和调试模型各个模块,快速地修改、迭代模型。

然而,与静态图相比,动态图的执行效率可能会较低。因为动态图难以进行一些计算图的优化,如运算符融合、图优化等。而且,动态图依赖于Python 环境。这些因素使得动态图不适合在低延迟要求较高的生产环境下部署。

因此,在部署Pytorch训练后的模型时,需要将动态图转换为静态图,这就要用到TorchScript。

TorchScript

TorchScript是PyTorch模型的一种静态图表示形式,支持模型的部署优化、跨平台部署以及与其他深度学习框架的集成:

  • 模型的部署优化:TorchScript 可以帮助优化 PyTorch 模型以提高性能和效率。通过将模型转换为静态图形式,TorchScript 可以应用各种优化技术,如运算符融合、图优化等,从而加速模型执行并降低内存消耗。
  • 跨平台部署:将模型转换为 TorchScript 格式可以实现跨平台部署,模型可以在没有 Python 环境的情况下运行。这对于在生产环境中部署模型到服务器、移动设备或边缘设备上非常有用。
  • 与其他框架集成:通过将 PyTorch 模型转换为 TorchScript 格式,可以更方便地与其他深度学习框架进行交互。例如,可以将TorchScript 进一步转换为 ONNX 格式,从而与 TensorFlow 等其他框架进行集成和交互操作。

Pytorch模型转换为TorchScript

torch.jit.tracetorch.jit.script 是 PyTorch 中用于模型转换为 TorchScript 格式的工具,但它们有不同的作用和使用场景。

torch.jit.trace

通过torch.jit.trace 将 没有控制流的MyCell 模块转化为TorchScript:


import torch  # This is all you need to use both PyTorch and TorchScript!torch.manual_seed(191009)  # set the seed for reproducibilityclass MyCell(torch.nn.Module):def __init__(self):super(MyCell, self).__init__()self.linear = torch.nn.Linear(4, 4)def forward(self, x, h):new_h = torch.tanh(self.linear(x) + h)return new_h, new_hmy_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.trace(my_cell, (x, h))
print(traced_cell)
MyCell(original_name=MyCell(linear): Linear(original_name=Linear)
)

torch.jit.trace调用了my_cell,记录了模块计算时发生的操作,并创建了一个torch.jit.ScriptModule的实例(TracedModule是其实例)traced_celltraced_cell 记录了my_cell的计算图。我们可以使用.graph属性来查看:

print(traced_cell.graph)
graph(%self.1 : __torch__.MyCell,%x : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu),%h : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):%linear : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="linear"](%self.1)%20 : Tensor = prim::CallMethod[name="forward"](%linear, %x)%11 : int = prim::Constant[value=1]() # /var/lib/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:189:0%12 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::add(%20, %h, %11) # /var/lib/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:189:0%13 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::tanh(%12) # /var/lib/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:189:0%14 : (Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu), Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu)) = prim::TupleConstruct(%13, %13)return (%14)

然而,图中包含的大多数信息对我们没有用处。我们可以使用.code属性对其进行Python语法解释:

print(traced_cell.code)
def forward(self,x: Tensor,h: Tensor) -> Tuple[Tensor, Tensor]:linear = self.linear_0 = torch.tanh(torch.add((linear).forward(x, ), h))return (_0, _0)

调用traced_cell会产生与Python模块实例my_cell() 相同的结果:

print(my_cell(x, h))
print(traced_cell(x, h))
(tensor([[-0.2541,  0.2460,  0.2297,  0.1014],[-0.2329, -0.2911,  0.5641,  0.5015],[ 0.1688,  0.2252,  0.7251,  0.2530]], grad_fn=<TanhBackward0>), tensor([[-0.2541,  0.2460,  0.2297,  0.1014],[-0.2329, -0.2911,  0.5641,  0.5015],[ 0.1688,  0.2252,  0.7251,  0.2530]], grad_fn=<TanhBackward0>))
(tensor([[-0.2541,  0.2460,  0.2297,  0.1014],[-0.2329, -0.2911,  0.5641,  0.5015],[ 0.1688,  0.2252,  0.7251,  0.2530]], grad_fn=<TanhBackward0>), tensor([[-0.2541,  0.2460,  0.2297,  0.1014],[-0.2329, -0.2911,  0.5641,  0.5015],[ 0.1688,  0.2252,  0.7251,  0.2530]], grad_fn=<TanhBackward0>))

torch.jit.script

我们先尝试通过torch.jit.trace 将 带有控制流的MyCell 模块转化为TorchScript:

class MyDecisionGate(torch.nn.Module):def forward(self, x):if x.sum() > 0:return xelse:return -xclass MyCell(torch.nn.Module):def __init__(self, dg):super(MyCell, self).__init__()self.dg = dgself.linear = torch.nn.Linear(4, 4)def forward(self, x, h):new_h = torch.tanh(self.dg(self.linear(x)) + h)return new_h, new_hmy_cell = MyCell(MyDecisionGate())
traced_cell = torch.jit.trace(my_cell, (x, h))print(traced_cell.dg.code)
print(traced_cell.code)
/var/lib/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:261: TracerWarning:Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!def forward(self,argument_1: Tensor) -> NoneType:return Nonedef forward(self,x: Tensor,h: Tensor) -> Tuple[Tensor, Tensor]:dg = self.dglinear = self.linear_0 = (linear).forward(x, )_1 = (dg).forward(_0, )_2 = torch.tanh(torch.add(_0, h))return (_2, _2)

可以看到,if-else分支并没有被表示出来。为什么?
trace记录代码运行发生的操作,并构造一个ScriptModule。控制流中只有一种情况被记录了下来,其他情况都被忽略了。

这就需要用到torch.jit.script了:

scripted_gate = torch.jit.script(MyDecisionGate())my_cell = MyCell(scripted_gate)
scripted_cell = torch.jit.script(my_cell)print(scripted_gate.code)
print(scripted_cell.code)
def forward(self,x: Tensor) -> Tensor:if bool(torch.gt(torch.sum(x), 0)):_0 = xelse:_0 = torch.neg(x)return _0def forward(self,x: Tensor,h: Tensor) -> Tuple[Tensor, Tensor]:dg = self.dglinear = self.linear_0 = torch.add((dg).forward((linear).forward(x, ), ), h)new_h = torch.tanh(_0)return (new_h, new_h)

可以考到,控制流也被记录了下来。
现在让我们尝试运行该程序:

# New inputs
x, h = torch.rand(3, 4), torch.rand(3, 4)
print(scripted_cell(x, h))
(tensor([[ 0.5679,  0.5762,  0.2506, -0.0734],[ 0.5228,  0.7122,  0.6985, -0.0656],[ 0.6187,  0.4487,  0.7456, -0.0238]], grad_fn=<TanhBackward0>), tensor([[ 0.5679,  0.5762,  0.2506, -0.0734],[ 0.5228,  0.7122,  0.6985, -0.0656],[ 0.6187,  0.4487,  0.7456, -0.0238]], grad_fn=<TanhBackward0>))

trace和script的区别总结

  • torch.jit.tracetorch.jit.trace 用于将一个具体的输入示例追踪(trace)模型的一次计算过程,从而生成一个 TorchScript 模型。对于动态控制流(如条件语句),它只会记录每个分支中的一种情况。因此,它不适用于无固定形状输入、具有动态控制流的模型。

  • torch.jit.scripttorch.jit.script 用于将整个 PyTorch 模型转换为 TorchScript 模型,包括模型的所有逻辑和控制流。script适用于无固定形状输入、具有动态控制流的模型 。但是,它可能会把保存一些多余的代码, 产生额外的性能开销。

因此,可以将两者混合使用,扬长避短。

trace 和script 混合使用

torch.jit.tracetorch.jit.script 可以混合使用: 复杂模型中静态部分用torch.jit.trace进行转换, 动态部分用torch.jit.script 进行转换,以发挥各自的优势。以下是两个可能的情况:

  • torch.jit.script内联traced模块的代码,
class MyRNNLoop(torch.nn.Module):def __init__(self):super(MyRNNLoop, self).__init__()self.cell = torch.jit.trace(MyCell(scripted_gate), (x, h))def forward(self, xs):h, y = torch.zeros(3, 4), torch.zeros(3, 4)for i in range(xs.size(0)):y, h = self.cell(xs[i], h)return y, hrnn_loop = torch.jit.script(MyRNNLoop())
print(rnn_loop.code)
def forward(self,xs: Tensor) -> Tuple[Tensor, Tensor]:h = torch.zeros([3, 4])y = torch.zeros([3, 4])y0 = yh0 = hfor i in range(torch.size(xs, 0)):cell = self.cell_0 = (cell).forward(torch.select(xs, 0, i), h0, )y1, h1, = _0y0, h0 = y1, h1return (y0, h0)
  • torch.jit.trace内联scripted模块的代码,
class WrapRNN(torch.nn.Module):def __init__(self):super(WrapRNN, self).__init__()self.loop = torch.jit.script(MyRNNLoop())def forward(self, xs):y, h = self.loop(xs)return torch.relu(y)traced = torch.jit.trace(WrapRNN(), (torch.rand(10, 3, 4)))
print(traced.code)
def forward(self,xs: Tensor) -> Tensor:loop = self.loop_0, y, = (loop).forward(xs, )return torch.relu(y)

保存和加载模型

  • traced.save : 保存TorchScript

  • torch.jit.load : 加载TorchScript

traced.save('wrapped_rnn.pt')loaded = torch.jit.load('wrapped_rnn.pt')print(loaded)
print(loaded.code)
RecursiveScriptModule(original_name=WrapRNN(loop): RecursiveScriptModule(original_name=MyRNNLoop(cell): RecursiveScriptModule(original_name=MyCell(dg): RecursiveScriptModule(original_name=MyDecisionGate)(linear): RecursiveScriptModule(original_name=Linear)))
)
def forward(self,xs: Tensor) -> Tensor:loop = self.loop_0, y, = (loop).forward(xs, )return torch.relu(y)

参考:
https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/314744.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Vue3+Tres 三维开发】02-Debug

预览 介绍 Debug 这里主要是讲在三维中的调试,同以前threejs中使用的lil-gui类似,TRESJS也提供了一套可视化参数调试的插件。使用方式和之前的组件相似。 使用 通过导入useTweakPane 即可 import { useTweakPane, OrbitControls } from "@tresjs/cientos"const {…

【六十】【算法分析与设计】用一道题目解决dfs深度优先遍历,dfs中节点信息,dfs递归函数模板进入前维护出去前回溯,唯一解的剪枝飞升返回值true

路径之谜 题目描述 小明冒充X星球的骑士,进入了一个奇怪的城堡。 城堡里边什么都没有,只有方形石头铺成的地面。 假设城堡地面是nn个方格。如下图所示。 按习俗,骑士要从西北角走到东南角。可以横向或纵向移动,但不能斜着音走,也不能跳跃。每走到一个新方格,就要向正北 方和正西…

ADOP带您科普什么是单纤双向BiDi光模块?一根光纤,双向通信:单纤双向模块的革命性技术。

单纤双向光模块&#xff08;也称为BiDi光模块&#xff09;是一种使用WDM&#xff08;波分复用&#xff09;双向传输技术的光模块&#xff0c;它在一根光纤上实现了同时进行光通道内的双向传输。相比常规光模块&#xff08;有两个光纤插孔&#xff09;&#xff0c;BiDi光模块只有…

基于Python+Selenium+Pytest的Dockerfile如何写

使用 Dockerfile 部署 Python 应用程序与 Selenium 测试 在本文中&#xff0c;我们将介绍如何使用 Dockerfile 部署一个 Python 应用程序&#xff0c;同时利用 Selenium 进行自动化测试。我们将使用官方的 Python 运行时作为父镜像&#xff0c;并在其中安装所需的依赖项和工具…

【Node.js工程师养成计划】之打造自己的脚手架工具

一、创建全局的自定义命令 1、打开一个空文件夹&#xff0c;新建一个bin文件夹&#xff0c;在bin文件夹下新建cli.js文件&#xff0c;js文件可以命名为cli.js&#xff08;您随意&#xff09; 2、在cli.js文件中的开头&#xff08;&#xff01;&#xff01;&#xff09;写下面这…

windows环境下安装Apache

首先apache官网下载地址&#xff1a;http://www.apachelounge.com/download/按照自己的电脑操作系统来安装 这里我安装的是win64 主版本是2.4的apache。 然后解压压缩包到一个全英文的路径下&#xff01;&#xff01;&#xff01;一定一定不要有中文 中文符号也不要有&#xff…

十一、Yocto集成tcpdump等网络工具

文章目录 Yocto集成tcpdump等网络工具networking layer集成 Yocto集成tcpdump等网络工具 本篇文章为基于raspberrypi 4B单板的yocto实战系列的第十一篇文章&#xff1a; 一、yocto 编译raspberrypi 4B并启动 二、yocto 集成ros2(基于raspberrypi 4B) 三、Yocto创建自定义的lay…

RabbitMQ工作模式(5) - 主题模式

概念 主题模式&#xff08;Topic Exchange&#xff09;是 RabbitMQ 中一种灵活且强大的消息传递模式&#xff0c;它允许生产者根据消息的特定属性将消息发送到一个交换机&#xff0c;并且消费者可以根据自己的需求来接收感兴趣的消息。主题交换机根据消息的路由键和绑定队列的路…

演示在一台Windows主机上运行两个Mysql服务器(端口号3306 和 3307),安装步骤详解

目录 在一台Windows主机上运行两个Mysql服务器&#xff0c;安装步骤详解因为演示需要两个 MySQL 服务器终端&#xff0c;我只有一个 3306 端口号的 MySQL 服务器&#xff0c;所以需要再创建一个 3307 的。创建一个3307端口号的MySQL服务器1、复制 mysql 的安装目录2、修改my.in…

NAT网络地址转换实验(华为)

思科设备参考&#xff1a;NAT网络地址转换实验&#xff08;思科&#xff09; 一&#xff0c;技术简介 NAT&#xff08;Network Address Translation&#xff09;&#xff0c;即网络地址转换技术&#xff0c;是一种在现代计算机网络中广泛应用的技术&#xff0c;主要用于有效管…

nvm基本使用

nvm基本使用 文章目录 nvm基本使用1.基本介绍2.下载地址3.常用指令 1.基本介绍 NVM是一个用于管理 Node.js 版本的工具。它允许您在同一台计算机上同时安装和管理多个 Node.js 版本&#xff0c;针对于不同的项目可能需要不同版本的 Node.js 运行环境。 NVM 主要功能&#xff…

百度智能云千帆 ModelBuilder 技术实践系列:通过 SDK 快速构建并发布垂域模型

​百度智能云千帆大模型平台&#xff08;百度智能云千帆大模型平台 ModelBuilder&#xff09;作为面向企业开发者的一站式大模型开发平台&#xff0c;自上线以来受到了广大开发者、企业的关注。至今已经上线收纳了超过 70 种预置模型服务&#xff0c;用户可以快速的调用&#x…

STM32的端口引脚的复用功能及重映射功能解析

目录 STM32的端口引脚的复用功能及重映射功能解析 复用功能 复用功能的初始化 重映射功能 重映射功能的初始化 复用功能和重映射的区别 部分重映射与完全重映射 补充 STM32的端口引脚的复用功能及重映射功能解析 复用功能 首先、我们可以这样去理解stm32引脚的复用功能…

docker容器技术篇:容器集群管理实战mesos+zookeeper+marathon(一)

容器集群管理实战mesoszookeepermarathon&#xff08;一&#xff09; mesos概述 1.1 Mesos是什么 Apache Mesos 是一个基于多资源调度的集群管理软件&#xff0c;提供了有效的、跨分布式应用或框架的资源隔离和共享&#xff0c;可以运行 Hadoop、Spark以及docker等。 1.2 为…

Flowable 基本用法

一. 什么是Flowable Flowable 是一个基于 Java 的开源工作流引擎&#xff0c;用于实现和管理业务流程。它提供了强大的工作流引擎和一套丰富的工具&#xff0c;使开发人员能够轻松地建模、部署、执行和监控各种类型的业务流程。Flowable 是 Activiti 工作流引擎的一个分支&am…

服务网关GateWay基础

1. 网关基础介绍1.1 网关是什么1.2 为啥要用网关1.3 常见的网关组件NginxNetflix ZuulSpring Cloud GatewayKongAPISIX综合比较 2. gateWay的使用2.1 springCloud整合gateway2.2 GateWay的相关用法2.3 GateWay路由使用示例基本用法转发/重定向负载请求动态路由 2.5 断言(Predic…

使用Screenshots安装Fedora 40版本详细教程

Fedora 40是Fedora操作系统的最新版本&#xff0c;于 2024 年 4 月 23 日发布&#xff0c;是一个社区支持的 Linux 发行版&#xff0c;以其创新功能、领先技术和活跃的社区支持而闻名。 在本指南中&#xff0c;我们将引导您完成安装Fedora 40 Server的分步过程&#xff0c;确保…

Docker之存储配置与管理

一、容器本地配置与Docker存储驱动 每个容器都被自动分配了本地存储&#xff0c;也就是内部存储。容器由一个可写容器层和若干只读镜像层组成&#xff0c;容器的数据就存放在这些层中。 容器本地存储采用的是联合文件系统。这种文件系统将其他文件系统合并到一个联合挂载点&a…

【c++】深入剖析与动手实践:C++中Stack与Queue的艺术

&#x1f525;个人主页&#xff1a;Quitecoder &#x1f525;专栏&#xff1a;c笔记仓 朋友们大家好&#xff0c;本篇文章我们来到STL新的内容&#xff0c;stack和queue 目录 1. stack的介绍与使用函数介绍例题一&#xff1a;最小栈例题二&#xff1a;栈的压入、弹出队列栈的模…

springcloud按版本发布微服务达到不停机更新的效果

本文基于以下环境完成 spring-boot 2.3.2.RELEASEspring-cloud Hoxton.SR9spring-cloud-alibaba 2.2.6.RELEASEspring-cloud-starter-gateway 2.2.6.RELEASEspring-cloud-starter-loadbalancer 2.2.6.RELEASEnacos 2.0.3 一、思路 实现思路&#xff1a; 前端项目在请求后端接…