文章目录
- (十三)模型部署: TorchScript
- Pytorch动态图的优缺点
- TorchScript
- Pytorch模型转换为TorchScript
- torch.jit.trace
- torch.jit.script
- trace和script的区别总结
- trace 和script 混合使用
- 保存和加载模型
(十三)模型部署: TorchScript
Pytorch动态图的优缺点
与Tensorflow使用静态计算图不同,PyTorch 使用的是动态计算图:
动态图允许在运行时渐进地构建计算图,使得模型设计更加灵活。开发者可以使用 Python 的控制流结构(如循环、条件语句等)来动态地定义模型的结构,从而更容易实现复杂的模型逻辑。
这种计算方式更直观,更pythonic。开发者可以更容易地理解和调试模型各个模块,快速地修改、迭代模型。
然而,与静态图相比,动态图的执行效率可能会较低。因为动态图难以进行一些计算图的优化,如运算符融合、图优化等。而且,动态图依赖于Python 环境。这些因素使得动态图不适合在低延迟要求较高的生产环境下部署。
因此,在部署Pytorch训练后的模型时,需要将动态图转换为静态图,这就要用到TorchScript。
TorchScript
TorchScript是PyTorch模型的一种静态图表示形式,支持模型的部署优化、跨平台部署以及与其他深度学习框架的集成:
- 模型的部署优化:TorchScript 可以帮助优化 PyTorch 模型以提高性能和效率。通过将模型转换为静态图形式,TorchScript 可以应用各种优化技术,如运算符融合、图优化等,从而加速模型执行并降低内存消耗。
- 跨平台部署:将模型转换为 TorchScript 格式可以实现跨平台部署,模型可以在没有 Python 环境的情况下运行。这对于在生产环境中部署模型到服务器、移动设备或边缘设备上非常有用。
- 与其他框架集成:通过将 PyTorch 模型转换为 TorchScript 格式,可以更方便地与其他深度学习框架进行交互。例如,可以将TorchScript 进一步转换为 ONNX 格式,从而与 TensorFlow 等其他框架进行集成和交互操作。
Pytorch模型转换为TorchScript
torch.jit.trace
和 torch.jit.script
是 PyTorch 中用于模型转换为 TorchScript 格式的工具,但它们有不同的作用和使用场景。
torch.jit.trace
通过torch.jit.trace
将 没有控制流的MyCell
模块转化为TorchScript:
import torch # This is all you need to use both PyTorch and TorchScript!torch.manual_seed(191009) # set the seed for reproducibilityclass MyCell(torch.nn.Module):def __init__(self):super(MyCell, self).__init__()self.linear = torch.nn.Linear(4, 4)def forward(self, x, h):new_h = torch.tanh(self.linear(x) + h)return new_h, new_hmy_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.trace(my_cell, (x, h))
print(traced_cell)
MyCell(original_name=MyCell(linear): Linear(original_name=Linear)
)
torch.jit.trace
调用了my_cell
,记录了模块计算时发生的操作,并创建了一个torch.jit.ScriptModule
的实例(TracedModule
是其实例)traced_cell
。traced_cell
记录了my_cell
的计算图。我们可以使用.graph
属性来查看:
print(traced_cell.graph)
graph(%self.1 : __torch__.MyCell,%x : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu),%h : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):%linear : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="linear"](%self.1)%20 : Tensor = prim::CallMethod[name="forward"](%linear, %x)%11 : int = prim::Constant[value=1]() # /var/lib/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:189:0%12 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::add(%20, %h, %11) # /var/lib/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:189:0%13 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::tanh(%12) # /var/lib/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:189:0%14 : (Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu), Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu)) = prim::TupleConstruct(%13, %13)return (%14)
然而,图中包含的大多数信息对我们没有用处。我们可以使用.code
属性对其进行Python语法解释:
print(traced_cell.code)
def forward(self,x: Tensor,h: Tensor) -> Tuple[Tensor, Tensor]:linear = self.linear_0 = torch.tanh(torch.add((linear).forward(x, ), h))return (_0, _0)
调用traced_cell
会产生与Python模块实例my_cell()
相同的结果:
print(my_cell(x, h))
print(traced_cell(x, h))
(tensor([[-0.2541, 0.2460, 0.2297, 0.1014],[-0.2329, -0.2911, 0.5641, 0.5015],[ 0.1688, 0.2252, 0.7251, 0.2530]], grad_fn=<TanhBackward0>), tensor([[-0.2541, 0.2460, 0.2297, 0.1014],[-0.2329, -0.2911, 0.5641, 0.5015],[ 0.1688, 0.2252, 0.7251, 0.2530]], grad_fn=<TanhBackward0>))
(tensor([[-0.2541, 0.2460, 0.2297, 0.1014],[-0.2329, -0.2911, 0.5641, 0.5015],[ 0.1688, 0.2252, 0.7251, 0.2530]], grad_fn=<TanhBackward0>), tensor([[-0.2541, 0.2460, 0.2297, 0.1014],[-0.2329, -0.2911, 0.5641, 0.5015],[ 0.1688, 0.2252, 0.7251, 0.2530]], grad_fn=<TanhBackward0>))
torch.jit.script
我们先尝试通过torch.jit.trace
将 带有控制流的MyCell
模块转化为TorchScript:
class MyDecisionGate(torch.nn.Module):def forward(self, x):if x.sum() > 0:return xelse:return -xclass MyCell(torch.nn.Module):def __init__(self, dg):super(MyCell, self).__init__()self.dg = dgself.linear = torch.nn.Linear(4, 4)def forward(self, x, h):new_h = torch.tanh(self.dg(self.linear(x)) + h)return new_h, new_hmy_cell = MyCell(MyDecisionGate())
traced_cell = torch.jit.trace(my_cell, (x, h))print(traced_cell.dg.code)
print(traced_cell.code)
/var/lib/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:261: TracerWarning:Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!def forward(self,argument_1: Tensor) -> NoneType:return Nonedef forward(self,x: Tensor,h: Tensor) -> Tuple[Tensor, Tensor]:dg = self.dglinear = self.linear_0 = (linear).forward(x, )_1 = (dg).forward(_0, )_2 = torch.tanh(torch.add(_0, h))return (_2, _2)
可以看到,if-else
分支并没有被表示出来。为什么?
trace
记录代码运行发生的操作,并构造一个ScriptModule
。控制流中只有一种情况被记录了下来,其他情况都被忽略了。
这就需要用到torch.jit.script
了:
scripted_gate = torch.jit.script(MyDecisionGate())my_cell = MyCell(scripted_gate)
scripted_cell = torch.jit.script(my_cell)print(scripted_gate.code)
print(scripted_cell.code)
def forward(self,x: Tensor) -> Tensor:if bool(torch.gt(torch.sum(x), 0)):_0 = xelse:_0 = torch.neg(x)return _0def forward(self,x: Tensor,h: Tensor) -> Tuple[Tensor, Tensor]:dg = self.dglinear = self.linear_0 = torch.add((dg).forward((linear).forward(x, ), ), h)new_h = torch.tanh(_0)return (new_h, new_h)
可以考到,控制流也被记录了下来。
现在让我们尝试运行该程序:
# New inputs
x, h = torch.rand(3, 4), torch.rand(3, 4)
print(scripted_cell(x, h))
(tensor([[ 0.5679, 0.5762, 0.2506, -0.0734],[ 0.5228, 0.7122, 0.6985, -0.0656],[ 0.6187, 0.4487, 0.7456, -0.0238]], grad_fn=<TanhBackward0>), tensor([[ 0.5679, 0.5762, 0.2506, -0.0734],[ 0.5228, 0.7122, 0.6985, -0.0656],[ 0.6187, 0.4487, 0.7456, -0.0238]], grad_fn=<TanhBackward0>))
trace和script的区别总结
-
torch.jit.trace:
torch.jit.trace
用于将一个具体的输入示例追踪(trace)模型的一次计算过程,从而生成一个 TorchScript 模型。对于动态控制流(如条件语句),它只会记录每个分支中的一种情况。因此,它不适用于无固定形状输入、具有动态控制流的模型。 -
torch.jit.script:
torch.jit.script
用于将整个 PyTorch 模型转换为 TorchScript 模型,包括模型的所有逻辑和控制流。script
适用于无固定形状输入、具有动态控制流的模型 。但是,它可能会把保存一些多余的代码, 产生额外的性能开销。
因此,可以将两者混合使用,扬长避短。
trace 和script 混合使用
torch.jit.trace
和 torch.jit.script
可以混合使用: 复杂模型中静态部分用torch.jit.trace
进行转换, 动态部分用torch.jit.script
进行转换,以发挥各自的优势。以下是两个可能的情况:
torch.jit.script
内联traced模块的代码,
class MyRNNLoop(torch.nn.Module):def __init__(self):super(MyRNNLoop, self).__init__()self.cell = torch.jit.trace(MyCell(scripted_gate), (x, h))def forward(self, xs):h, y = torch.zeros(3, 4), torch.zeros(3, 4)for i in range(xs.size(0)):y, h = self.cell(xs[i], h)return y, hrnn_loop = torch.jit.script(MyRNNLoop())
print(rnn_loop.code)
def forward(self,xs: Tensor) -> Tuple[Tensor, Tensor]:h = torch.zeros([3, 4])y = torch.zeros([3, 4])y0 = yh0 = hfor i in range(torch.size(xs, 0)):cell = self.cell_0 = (cell).forward(torch.select(xs, 0, i), h0, )y1, h1, = _0y0, h0 = y1, h1return (y0, h0)
torch.jit.trace
内联scripted模块的代码,
class WrapRNN(torch.nn.Module):def __init__(self):super(WrapRNN, self).__init__()self.loop = torch.jit.script(MyRNNLoop())def forward(self, xs):y, h = self.loop(xs)return torch.relu(y)traced = torch.jit.trace(WrapRNN(), (torch.rand(10, 3, 4)))
print(traced.code)
def forward(self,xs: Tensor) -> Tensor:loop = self.loop_0, y, = (loop).forward(xs, )return torch.relu(y)
保存和加载模型
-
traced.save
: 保存TorchScript -
torch.jit.load
: 加载TorchScript
traced.save('wrapped_rnn.pt')loaded = torch.jit.load('wrapped_rnn.pt')print(loaded)
print(loaded.code)
RecursiveScriptModule(original_name=WrapRNN(loop): RecursiveScriptModule(original_name=MyRNNLoop(cell): RecursiveScriptModule(original_name=MyCell(dg): RecursiveScriptModule(original_name=MyDecisionGate)(linear): RecursiveScriptModule(original_name=Linear)))
)
def forward(self,xs: Tensor) -> Tensor:loop = self.loop_0, y, = (loop).forward(xs, )return torch.relu(y)
参考:
https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html