PyTorch 内 LibTorch/TorchScript 的使用

PyTorch 内 LibTorch/TorchScript 的使用

  • 1. .pt .pth .bin .onnx 格式
    • 1.1 模型的保存与加载到底在做什么?
    • 1.2 为什么要约定格式?
    • 1.3 格式汇总
      • 1.3.1 .pt .pth 格式
      • 1.3.2 .bin 格式
      • 1.3.3 直接保存完整模型
      • 1.3.4 .onnx 格式
      • 1.3.5 jit.trace
      • 1.3.6 jit.script
    • 1.4 总结
  • 2. TorchScript 的转换
    • 2.1 jit trace 注意事项
    • 2.2 jit trace 验证技巧
    • 2.3 混合使用 trace 和 script
    • 2.4 trace 和 script 的性能
    • 2.5 总结
  • 3. LibTorch 的使用
    • 3.1 LibTorch 的链接
    • 3.2 接口和实现

Reference:

  1. [Pytorch].pth转.pt文件
  2. Pytorch格式 .pt .pth .bin .onnx 详解
  3. pytorch 基于tracing/script方式转ONNX

在这里插入图片描述

1. .pt .pth .bin .onnx 格式

1.1 模型的保存与加载到底在做什么?

我们在使用pytorch构建模型并且训练完成后,下一步要做的就是把这个模型放到实际场景中应用,或者是分享给其他人学习、研究、使用。因此,我们开始思考一个问题,提供哪些模型信息,能够让对方能够完全复现我们的模型?

  • 模型代码
    1. 包含了我们如何定义模型的结构,包括模型有多少层/每层有多少神经元等等信息;
    2. 包含了我们如何定义的训练过程,包括epoch batch_size等参数;
    3. 包含了我们如何加载数据和使用;
    4. 包含了我们如何测试评估模型。
  • 模型参数:提供了模型代码之后,对方确实能够复现模型,但是运行的参数需要重新训练才能得到,而没有办法在我们的模型参数基础上继续训练,因此对方还希望我们能够把模型的参数也保存下来给对方。
    1. 包含model.state_dict(),这是模型每一层可学习的节点的参数,比如weight/bias;
    2. 包含optimizer.state_dict(),这是模型的优化器中的参数;
    3. 包含我们其他参数信息,如epoch/batch_size/loss等。
  • 数据集
    1. 包含了我们训练模型使用的所有数据;
    2. 可以提示对方如何去准备同样格式的数据来训练模型。
  • 使用文档
    1. 根据使用文档的步骤,每个人都可以重现模型;
    2. 包含了模型的使用细节和我们相关参数的设置依据等信息。

可以看到,根据我们提供的模型代码/模型参数/数据集/使用文档,我们就可以有理由相信对方是有手就会了,那么目的就达到了。

现在我们反转一下思路,我们希望别人给我们提供模型的时候也能够提供这些信息,那么我们就可以拿捏住别人的模型了。

1.2 为什么要约定格式?

根据上一段的思路,我们知道模型重现的关键是模型结构/模型参数/数据集,那么我们提供或者希望别人提供这些信息,需要一个交流的规范,这样才不会1000个人给出1000种格式,而 .pt .pth .bin 以及 .onnx 就是约定的格式。

torch.save: Saves a serialized object to disk. This function uses Python’s pickle utility for serialization. Models, tensors, and dictionaries of all kinds of objects can be saved using this function.

不同的后缀只是用于提示我们文件可能包含的内容,但是具体的内容需要看模型提供者编写的 README.md 才知道。而在使用 torch.load() 方法加载模型信息的时候,并不是根据文件的后缀进行的读取,而是根据文件的实际内容自动识别的,因此对于 torch.load() 方法而言,不管你把后缀改成是什么,只要文件是对的都可以读取

torch.load: Uses pickle’s unpickling facilities to deserialize pickled object files to memory. This function also facilitates the device to load the data into.

1.3 格式汇总

格式解释适用场景可对应的后缀
.pt 或 .pthPyTorch 的默认模型文件格式,用于保存和加载完整的 PyTorch 模型,包含模型的结构和参数等信息需要保存和加载完整的 PyTorch 模型的场景,例如在训练中保存最佳的模型或在部署中加载训练好的模型.pt 或 .pth
.bin一种通用的二进制格式,可以用于保存和加载各种类型的模型和数据需要将 PyTorch 模型转换为通用的二进制格式的场景.bin
ONNX一种通用的模型交换格式,可以用于将模型从一个深度学习框架转换到另一个深度学习框架或硬件平台。在 PyTorch 中,可以使用 torch.onnx.export 函数将 PyTorch 模型转换为 ONNX 格式需要将 PyTorch 模型转换为其他深度学习框架或硬件平台可用的格式的场景.onnx
TorchScriptPyTorch 提供的一种序列化和优化模型的方法,可以将 PyTorch 模型转换为一个序列化的程序,并使用 JIT 编译器对模型进行优化。在 PyTorch 中,可以使用 torch.git.trace 或 torch.git.script 函数将 PyTorch 模型转换为 TorchScript 格式需要将 PyTorch 模型序列化和优化,并在没有 Python 环境的情况下运行模型的场景.pt 或 .pth

1.3.1 .pt .pth 格式

一个完整的 PyTorch 模型文件,包含了如下参数:

  • model_state_dict:模型参数
  • optimizer_state_dict:优化器的状态
  • epoch:当前的训练轮数
  • loss:当前的损失值

下面是一个 .pt 文件的保存和加载示例(注意,后缀也可以是 .pth):

  • .state_dict():包含所有的参数和持久化缓存的字典,model 和 optimizer 都有这个方法
  • torch.save():将所有的组件保存到文件中

模型保存

import torch
import torch.nn as nn# 定义一个简单的模型
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.fc1 = nn.Linear(10, 5)self.fc2 = nn.Linear(5, 1)def forward(self, x):x = self.fc1(x)x = self.fc2(x)return xmodel = Net()optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)# 初始化优化器loss = nn.MSELoss()# 初始化损失函数PATH = "model.pth" # 保存路径# 保存模型
torch.save({'epoch': 10,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss,}, PATH)

netron 可得:
在这里插入图片描述

模型加载

import torch
import torch.nn as nn# 定义同样的模型结构
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.fc1 = nn.Linear(10, 5)self.fc2 = nn.Linear(5, 1)def forward(self, x):x = self.fc1(x)x = self.fc2(x)return x# 加载模型
model = Net()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
PATH = "model.pth"
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
model.eval()

1.3.2 .bin 格式

.bin 文件是一个二进制文件,可以保存 PyTorch 模型的参数和持久化缓存。.bin 文件的大小较小,加载速度较快,因此在生产环境中使用较多。

下面是一个.bin文件的保存和加载示例(注意:也可以使用 .pt .pth 后缀—后缀无意义):
保存模型

import torch
import torch.nn as nn# 定义一个简单的模型
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.fc1 = nn.Linear(10, 5)self.fc2 = nn.Linear(5, 1)def forward(self, x):x = self.fc1(x)x = self.fc2(x)return xmodel = Net()
# 保存参数到.bin文件
torch.save(model.state_dict(), PATH)

加载模型

import torch
import torch.nn as nn# 定义相同的模型结构
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.fc1 = nn.Linear(10, 5)self.fc2 = nn.Linear(5, 1)def forward(self, x):x = self.fc1(x)x = self.fc2(x)return x# 加载.bin文件
model = Net()
model.load_state_dict(torch.load(PATH))
model.eval()

1.3.3 直接保存完整模型

可以看出来,我们在之前的保存方式中,都是保存了 .state_dict(),但是没有保存模型的结构,在其他地方使用的时候,必须先重新定义相同结构的模型(或兼容模型),才能够加载模型参数进行使用,如果我们想直接把整个模型都保存下来,避免重新定义模型,可以按如下操作:
保存模型

PATH = "entire_model.pt"
# PATH = "entire_model.pth"
# PATH = "entire_model.bin"
torch.save(model, PATH)

netron 可得:
在这里插入图片描述

可以看到与上面仅保存参数的方式相比,多了很多信息。

加载模型

model = torch.load("entire_model.pt")
model.eval()

1.3.4 .onnx 格式

上述保存的文件可以通过 PyTorch 提供的 torch.onnx.export 函数转化为ONNX格式,这样可以在其他深度学习框架中使用 PyTorch 训练的模型。转化方法如下:

import torch
import torch.onnx# 将模型保存为.bin文件
model = torch.nn.Linear(3, 1)
torch.save(model.state_dict(), "model.bin")
# torch.save(model.state_dict(), "model.pt")
# torch.save(model.state_dict(), "model.pth")# 将.bin文件转化为ONNX格式
model = torch.nn.Linear(3, 1)
model.load_state_dict(torch.load("model.bin"))
# model.load_state_dict(torch.load("model.pt"))
# model.load_state_dict(torch.load("model.pth"))
example_input = torch.randn(1, 3)
torch.onnx.export(model, example_input, "model.onnx", input_names=["input"], output_names=["output"])

加载 ONNX 格式的代码可以参考以下示例代码(注意 ONNX 只能推理不能训练,不包含反向信息的):

import onnx
import onnxruntime# 加载ONNX文件
onnx_model = onnx.load("model.onnx")# 将ONNX文件转化为ORT格式
ort_session = onnxruntime.InferenceSession("model.onnx")# 输入数据
input_data = np.random.random(size=(1, 3)).astype(np.float32)# 运行模型
outputs = ort_session.run(None, {"input": input_data})# 输出结果
print(outputs)

注意,需要安装 onnxonnxruntime 两个 Python 包。此外,还需要使用 numpy 等其他常用的科学计算库。

1.3.5 jit.trace

保存模型

import torch
import torch.nn as nn# 定义一个简单的模型
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.fc1 = nn.Linear(10, 5)self.fc2 = nn.Linear(5, 1)def forward(self, x):x = self.fc1(x)x = self.fc2(x)return xmodel = Net()optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) # 初始化优化器
loss = nn.MSELoss() # 初始化损失函数
model.eval()PATH = "model_trace.pth"# 保存模型
example = torch.rand(1, 10)
traced_module = torch.jit.trace(model, example)
traced_module.save(PATH)

在这里插入图片描述

1.3.6 jit.script

保存模型

import torch
import torch.nn as nn# 定义一个简单的模型
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.fc1 = nn.Linear(10, 5)self.fc2 = nn.Linear(5, 1)def forward(self, x):x = self.fc1(x)x = self.fc2(x)return xmodel = Net()optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) # 初始化优化器
loss = nn.MSELoss() # 初始化损失函数
model.eval()PATH = "model_script.pth" # 保存路径# 保存模型
scripted_module = torch.jit.script(model)
scripted_module.save(PATH)

netron 可得:
在这里插入图片描述

1.4 总结

综上,PyTorch 可以导出的模型的几种后缀格式,但是模型导出的关键并不是后缀,而是到处时候提供的信息到底是什么,只要知道了模型的 model.state_dict()optimizer.state_dict(),以及相应的epoch batch_size loss等信息,我们就能够重建出模型,至于要导出哪些信息,就取决于你了,务必在 readme.md 中写清楚,导出了哪些信息。

保存场景保存方法文件后缀
整个模型(保存模型结构)model = Net()
torch.save(model, PATH)
.pt .pth .bin
仅模型参数(不保存模型结构)model = Net()
torch.save(model.state_dict(), PATH)
.pt .pth .bin
checkpoints使用model = Net()
torch.save({‘epoch’:10,‘model_state_dict’:model.state_dict(),‘optimizer_state_dict’: optimizer.state_dict(),‘loss’: loss,}, PATH)
.pt .pth .bin
ONNX通用保存model = Net()
model.load_state_dict(torch.load(“model.bin”))
example_input = torch.randn(1, 3)
torch.onnx.export(model, example_input, “model.onnx”, input_names=[“input”], output_names=[“output”])
.onnx
TorchScript 无 Python 环境使用model = Net()
model_scripted = torch.jit.script(model) # Export to TorchScript
model_scripted.save(‘model_scripted.pt’)
model = torch.jit.load(‘model_scripted.pt’)
model.eval()
.pt .pth

2. TorchScript 的转换

上文内提到 .pthpt 等价,而且后缀主要用于提示。不过相对来说,PyTorch 的模型文件一般保存为 .pth 文件的更多一点,而 C++ 接口一般读取的是 .pt 文件,因此,C++ 在调用 PyTorch 训练好的模型文件的时候,就需要转换为以 .pt 为代表的 TorchScript 文件,才能够读取。

Script mode 通过 torch.jit.trace 或者 torch.jit.script 来调用。这两个函数都是将 Python 代码转换为 TorchScript 的两种不同的方法。

  • torch.jit.trace:将一个特定的输入(通常是一个张量,需要我们提供一个input)传递给一个 PyTorch 模型,torch.jit.trace 会跟踪此 input 在 model 中的计算过程,然后将其转换为 Torch 脚本。这个方法适用于那些在静态图中可以完全定义的模型,例如具有固定输入大小的神经网络。通常用于转换预训练模型。

  • torch.jit.script 直接将 Python 函数(或者一个 Python 模块)通过 Python 语法规则和编译转换为 Torch 脚本。torch.jit.script 更适用于动态图模型,这些模型的结构和输入可以在运行时发生变化。例如,对于 RNN 或者一些具有可变序列长度的模型,使用 torch.jit.script 会更为方便。

在通常情况下,更应该倾向于使用 torch.jit.trace 而不是 torch.jit.script

在模型部署方面,ONNX 被大量使用。而导出 ONNX 的过程,也是 model 进行 torch.jit.trace 的过程,因此这里我们把 torch 的 trace 做稍微详细一点的介绍。

2.1 jit trace 注意事项

为了能够把模型编写的更能够被 jit trace,需要对代码做一些妥协,例如:

  1. 如果 model 中有 DataParallel 的子模块,或者 model 中有将 tensors 转换为 numpy arrays,或者调用了 OpenCV 的函数等,这种情况下,model 不是一个正确的在单个设备上、正确连接的 graph,这种情况下,不管是使用 torch.jit.script 还是 torch.jit.trace 都不能 trace 出正确的 TorchScript 来。

  2. model 的输入输出应该是 Union[Tensor, Tuple[Tensor], Dict[str, Tensor]] 的类型,而且在 dict 中的值,应该是同样的类型。但是对于 model 中间子模块的输入输出,可以是任意类型,例如 dicts of Any, classes, kwargs 以及 Python 支持的都可以。对于 model 输入输出类型的限制是比较容易满足的,在Detectron2中,有类似的例子:

    outputs = model(inputs)   # inputs和outputs是python的类型, 例如dictsor classes
    # torch.jit.trace(model, inputs)  # 失败!trace只支持Union[Tensor,Tuple[Tensor], Dict[str, Tensor]]类型
    adapter = TracingAdapter(model, inputs)  # 使用Adapter,将model inputs包装为trace支持的类型
    traced = torch.jit.trace(adapter, adapter.flattened_inputs)  # 现在以trace成功# Traced model的输出只能是tuple tensors类型:
    flattened_outputs = traced(*adapter.flattened_inputs)
    # 再通过adapter转换为想要的输出类型
    new_outputs = adapter.outputs_schema(flattened_outputs)
    
  3. 一些数值类型的问题。比如下面的代码片段:

    import torch
    a=torch.tensor([1,2])
    print(type(a.size(0)))
    print(type(a.size()[0]))
    print(type(a.shape[0]))
    

    在eager mode下,这几个返回值的类型都是int型。上面代码的输出为:

    <class 'int'>
    <class 'int'>
    <class 'int'>
    

    但是在 trace mode 下,这几个表达式的返回值类型都是 Tensor 类型。因此,有些表达式使用不当,如果在 trace 过程中,一些 shape 表达式的返回值类型是 int 型,那么可能造成这块代码没有被 trace。在代码中,可以通过使用 torch.jit.is_tracing 来检查这块代码在 trace mode 下有没有被执行。

  4. 由于动态的 control flow,造成模型没有被完整的 trace。看下面的例子:

    import torchdef f(x):return torch.sqrt(x) if x.sum() > 0 else torch.square(x)m = torch.jit.trace(f, torch.tensor(3))
    print(m.code)
    

    输出为:

    def f(x: Tensor) -> Tensor:return torch.sqrt(x)
    

    可以看到 trace 后的 model 只保留了一条分支。因此由于输入造成的 dynamic 的 control flow,trace 后容易出现错误。

    这种情况下,我们可以使用 torch.jit.script 来进行 TorchScript 的转换。

    import torchdef f(x):return torch.sqrt(x) if x.sum() > 0 else torch.square(x)m = torch.jit.script(f)
    print(m.code)
    

    输出为:

    def f(x: Tensor) -> Tensor:if bool(torch.gt(torch.sum(x), 0)):_0 = torch.sqrt(x)else:_0 = torch.square(x)return _0
    

    在大多数情况下,我们应该使用 torch.jit.trace,但是像上面的这种 dynamic control flow 的情况,我们可以混合使用 torch.jit.tracetorch.jit.script,在后面会进行阐述
    另外在一些 Blog 中,对于 dynamic control flow 的定义是有错误的,例如 if x[0] == 4: x += 1 是 dynamic control flow,但是:

    model: nn.Sequential = ...
    for m in model:x = m(x)
    

    以及:

    class A(nn.Module):backbone: nn.Modulehead: Optiona[nn.Module]def forward(self, x):x = self.backbone(x)if self.head is not None:x = self.head(x)return x
    

    都不是 dynamic control flowdynamic control flow 是由于对输入条件的判断造成的不同分支的执行

  5. trace 过程中,将变量 trace 成了常量。看下面一个例子:

    import torch
    a, b = torch.rand(1), torch.rand(2)def f1(x): return torch.arange(x.shape[0])
    def f2(x): return torch.arange(len(x))print(torch.jit.trace(f1, a)(b))
    # 输出: tensor([0, 1])
    # 可以看到trace后的model是没问题的,这里使用变量a作为torch.jit.trace的example input,然后将转换后的TorchScript用变量b作为输入,正常情况下,b的shape是2维的,因此返回值是tensor([0,1])是正确的print(torch.jit.trace(f2, a)(b))
    # 输出:
    # TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
    # tensor([0])
    # 可以看到这个输出结果是错误的,b的维度是2维,输出应该是tensor([0,1]),这里torch.jit.trace也提示了,使用len可能会造成不正确的trace。# 我们打印一下两者的区别
    print(torch.jit.trace(f1, a).code, '\n',torch.jit.trace(f2, a).code)
    # 输出
    # def f1(x: Tensor) -> Tensor:
    #   _0 = ops.prim.NumToTensor(torch.size(x, 0))
    #   _1 = torch.arange(annotate(number, _0), dtype=None, layout=None, device=torch.device("cpu"), pin_memory=False)
    #   return _1#  def f2(x: Tensor) -> Tensor:
    #   _0 = torch.arange(1, dtype=None, layout=None, device=torch.device("cpu"), pin_memory=False)
    #   return _0# TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.# 从trace的code中可以看出,使用x.shape这种方式,在trace后的code里面,是有shape的一个变量值存在的,但是直接使用len这种方式,trace后的code里面,就直接是1
    

    我们导出 ONNX 的过程,也是进行 torch.jit.trace 的过程,在导出 ONNX 的时候,有时候也会遇到

    TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.

    这样的提示信息,这时候要检查一下代码中是不是有可能 trace 过程中,变量会被当做常量的情况,有可能会导致导出的 ONNX 精度异常。

    • 关于 ONNX
      ONNX 默认基于 trace 的方式,运行一次模型,记录下和 tensor 的相关操作。trace 将不会捕获根据输入数据而改变的行为。比如 if 语句,只会记录执行的那一条分支,同样的,for 循环的次数,导出与跟踪运行完全相同的静态图。如果要使用动态控制流导出模型,则需要使用 torch.jit.script
      torch.jit.script:真正的去编译,在 PYTHON 的 AST 语法树做语法分析句法分析。因此可以使用if等动态控制流。返回 ScriptModule。
      torch.onnx.export 在运行时,先判断是否是 SriptModule,如果不是,则进行 torch.jit.trace,因此 export 需要一个随机生成的输入参数。
      import torch.nn as nn
      import torch
      import torch.nn.functional as F
      import cv2
      import numpy as np
      import onnx
      import onnxruntime as ort#from torch.onnx import register_custom_op_symbolic # 私有层支持class test_net(nn.Module):def __init__(self,):super(test_net, self).__init__()#self.model = nn.MaxPool3d(kernel_size=(1,3,3), stride=(2,1,2))#self.model = nn.AvgPool3d(kernel_size=(1,3,3), stride=(2,1,2)) #-> AveragePoolself.model = nn.Conv3d(3,64,kernel_size=(1,3,3), stride=(2,1,2))self.relu = nn.ReLU()self.relu6 = nn.ReLU6()self.relu66 = nn.ReLU6()def forward(self, x):out1 = self.model(x)f_mean = torch.mean(out1) # -> ReduceMean#f_mean = torch.mean(out1).item() # item()会将f_mean转换为常数 会丢失 mean操作# script模式转onnx会报错 torch._C._jit_pass_erase_number_types(graph) RuntimeError: Unknown number type: Scalarout2 = torch.div(out1, f_mean)#outlist = list()#for i in range(3):#    if i in [0]:#        #outlist.append(nn.ReLU()(out2))  # script模式下报错 类对象要提前构建#        outlist.append(self.relu(out2))   # scrip_to_onnx 报错 找不到25 BUG#    else:#        #outlist.append(nn.ReLU6()(out2))#        outlist.append(self.relu6(out2))#out = torch.cat(outlist)# 上述 for循环构图在tracing模式下会展开# script模式下难转换,报错# 手动平铺o1 = self.relu(out2)o2 = self.relu6(out2)#o3 = self.relu6(out2)   # script模式下被优化掉了 BUGo3 = self.relu66(out2)   # script模式下被优化掉了out = torch.cat([o1,o2,o3])return out# 模型构建和运行
      imgh, imgw = 24, 94
      net = test_net().eval() # 若存在batchnorm、dropout层则一定要eval() 使得BN层参数不更新
      dummy_input = torch.randn(1,3,3,imgh, imgw)# n c d h w
      torch_out = net.forward(dummy_input)# net(dummy_input)# export onnx
      dynamic_axes = {'input': {3: 'height', 4: 'width'}, 'output': {3: 'height', 4: 'width'}} # 配置动态分辨率
      onnx_pth = "test-conv-relu.onnx"# 传入原model,采用默认trace方式捕获模型,需要运行模型
      torch.onnx.export(net, dummy_input, onnx_pth, input_names=['input'], output_names=['output'], dynamic_axes=dynamic_axes)
      # 也可传入 scriptModule
      #net_script= torch.jit.script(test_net())
      # 需要外加配置 example_outputs,用来获取输出的shape和dtype,无需运行模型
      #torch.onnx.export(net_script, dummy_input, onnx_pth, input_names=['input'], output_names=['output'], dynamic_axes=dynamic_axes, example_outputs=[torch_out])# ort run
      oxx_m = ort.InferenceSession(onnx_pth)
      onnx_blob = dummy_input.data.numpy()
      onnx_out = oxx_m.run(None, {'input':onnx_blob})[0]dummy_input2 = torch.randn(1,3,3,imgh*2, imgw*2)
      onnx_blob2 = dummy_input2.data.numpy()
      onnx_out2 = oxx_m.run(None, {'input':onnx_blob2})[0]# opencv run
      #cv_m = cv2.dnn.readNet(onnx_pth)print('mean diff = ', np.mean(onnx_out - torch_out.data.numpy()))
      

    除了 len 会导致 trace 错误,其他几个也会导致 trace 出现问题:

    • .item() 会在 trace 过程中将 tensors 转为 int/float

    • 任何将 torch 类型转为 numpy/python 类型的代码

    • 一些有问题的算子,例如 advanced indexing

    • torch.jit.trace 不会对传入的 device 生效

      import torch
      def f(x):return torch.arange(x.shape[0], device=x.device)
      m = torch.jit.trace(f, torch.tensor([3]))
      print(m.code)
      # 输出
      # def f(x: Tensor) -> Tensor:
      #   _0 = ops.prim.NumToTensor(torch.size(x, 0))
      #   _1 = torch.arange(annotate(number, _0), dtype=None, layout=None, device=torch.device("cpu"), pin_memory=False)
      #   return _1
      print(m(torch.tensor([3]).cuda()).device)
      # 输出:device(type='cpu')
      

      trace 不会对传入的 cuda device 生效。

2.2 jit trace 验证技巧

为了保证trace的正确,我们可以通过一下的一些方法来尽量保证 trace 后的模型不会出错:
1.注意 warnings 信息。类似这样的:

TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.

TraceWarnings信息,它会造成模型的结果有可能不正确,但是它只是个 warning 等级。
2. 做单元测试。需要验证一下 eager mode 的模型输出与 trace 后的模型输出是否一致。

assert allclose(torch.jit.trace(model, input1)(input2), model(input2))
  1. 避免一些特殊的情况。例如下面的代码:
if x.numel() > 0:output = self.layers(x)
else:output = torch.zeros((0, C, H, W))  # 会创建一个空的输出

避免一些特殊情况比如空的输入输出之类的。

  1. 注意shape的使用。前面提到,tensor.size()在trace过程中会返回Tensor类型的数据,Tensor类型会在计算过程中被添加到计算图中,应该避免将Tensor类型的shape转为了常量。主要注意以下两点:
  • 使用 torch.size(0) 来代替 len(tensor),因为 torch.size(0) 返回的是 Tensor,len(tensor) 返回的是 int。对于自定义类,实现一个 .size 方法或者使用 .__len__() 方法来代替 len() ,例如这个例子
  • 不要使用 int() 或者 torch.as_tensor 来转换 size 的类型,因为这些操作也会被视为常量。
  1. 混合 tracing 和 scripting 方法。可以使用 torch.jit.script 来转换一些 torch.jit.trace 不能搞定的小的代码片段,混合使用 tracing 和 scripting,基本可以解决所有的问题。

2.3 混合使用 trace 和 script

trace 和 script 都有他们的问题,混合使用可以解决大部分问题。但是为了尽可能减小对于代码质量的负面影响,大部分情况下,都应该使用 torch.jit.trace,必要时才使用 torch.jit.script

  1. 在使用 torch.jit.trace 时,使用 @script_if_tracing 装饰器可以让被装饰的函数使用 script 方式进行编译

    def forward(self, ...):# ... some forward logic@torch.jit.script_if_tracingdef _inner_impl(x, y, z, flag: bool):# use control flow, etc.return ...output = _inner_impl(x, y, z, flag)# ... other forward logic
    

    但是使用 @script_if_tracing 时,需要保证函数中没有 PyTorch 的 modules,如果有的话,需要做一些修改,例如下面的:

    # 因为代码中有self.layers(),是一个pytorch的module,因此不能使用@script_if_tracing
    if x.numel() > 0:x = preprocess(x)output = self.layers(x)
    else:# Create empty outputsoutput = torch.zeros(...)
    

    这里需要做如下修改:

    # 需要将self.layers移出if判断,这时候可以用@script_if_tracing
    if x.numel() > 0:x = preprocess(x)
    else:# Create empty inputsx = torch.zeros(...)
    # 需要将self.layers()修改为支持empty的输入,或者将原先的条件判断加入到self.layers中
    output = self.layers(x)
    
  2. 合并多次 trace 的结果
    使用 torch.jit.script 生成的模型相比使用 torch.jit.trace 有两个好处:

    • 可以使用条件控制流,例如模型中使用一个 bool 值来控制 forward 的 flow,在 traced modules 里面是不支持的
    • 使用 traced module,只能有一个 forward() 函数,但是使用 scripted module,可以有多个前向计算的函数
    class Detector(nn.Module):do_keypoint: booldef forward(self, img):box = self.predict_boxes(img)if self.do_keypoint:kpts = self.predict_keypoint(img, box)@torch.jit.exportdef predict_boxes(self, img): pass@torch.jit.exportdef predict_keypoint(self, img, box): pass
    

    对于这种有 bool 值的控制流,除了使用 script,还可以多次进行 trace,然后将结果合并。

    det1 = torch.jit.trace(Detector(do_keypoint=True), inputs)
    det2 = torch.jit.trace(Detector(do_keypoint=False), inputs)
    

    然后将他们的 weight 复制一遍,并合并两次 trace 的结果:

    det2.submodule.weight = det1.submodule.weight
    class Wrapper(nn.ModuleList):def forward(self, img, do_keypoint: bool):if do_keypoint:return self[0](img)else:return self[1](img)
    exported = torch.jit.script(Wrapper([det1, det2]))
    

    对于这种有 bool 值的控制流,除了使用 script,还可以多次进行 trace,然后将结果合并。

    det1 = torch.jit.trace(Detector(do_keypoint=True), inputs)
    det2 = torch.jit.trace(Detector(do_keypoint=False), inputs)
    

    然后将他们的 weight 复制一遍,并合并两次 trace 的结果:

    det2.submodule.weight = det1.submodule.weight
    class Wrapper(nn.ModuleList):def forward(self, img, do_keypoint: bool):if do_keypoint:return self[0](img)else:return self[1](img)
    exported = torch.jit.script(Wrapper([det1, det2]))
    

2.4 trace 和 script 的性能

trace 总是会比 script 生成一样或者更简单的计算图,因此性能会更好一些。因为 script 会完整的表达 Python 代码的逻辑,甚至一些不必要的代码也会如实表达。例如下面的例子:

class A(nn.Module):def forward(self, x1, x2, x3):z = [0, 1, 2]xs = [x1, x2, x3]for k in z: x1 += xs[k]return x1
model = A()
print(torch.jit.script(model).code)
# def forward(self, x1: Tensor, x2: Tensor, x3: Tensor) -> Tensor:
#   z = [0, 1, 2]
#   xs = [x1, x2, x3]
#   x10 = x1
#   for _0 in range(torch.len(z)):
#     k = z[_0]
#     x10 = torch.add_(x10, xs[k])
#   return x10
print(torch.jit.trace(model, [torch.tensor(1)] * 3).code)
# def forward(self, x1: Tensor, x2: Tensor, x3: Tensor) -> Tensor:
#   x10 = torch.add_(x1, x1)
#   x11 = torch.add_(x10, x2)
#   return torch.add_(x11, x3)

2.5 总结

trace 具有明显的局限性:这篇文章的大部分篇幅都在谈论 trace 的局限性以及如何解决这些问题。实际上,这正是 trace 的优势所在:它有明确的局限性(和解决方案),因此你可以推理它是否有效。

相反,script 更像是一个黑盒子:在尝试之前,没有人知道它是否有效。文章中没有提到如何修复 script 的任何诀窍:有很多诀窍,但不值得你花时间去探究和修复一个黑盒子。

trace 和 script 都会影响代码的编写方式,但 trace 因为我们明确它的要求,对我们原始的代码造成的一些修改也不会太严重:

  • 它限制了输入/输出格式,但仅限于最外层的模块。(如上所述,这个问题可以通过一个wrapper解决)。
  • 它需要修改一些代码才能通用(例如在 trace 时添加一些 script),但这些修改只涉及受影响模块的内部实现,而不是它们的接口。

3. LibTorch 的使用

在得到所需模型后,可以尝试在 C++ 环境下使用得到的模型,这里就用到了 LibTorch。

3.1 LibTorch 的链接

结合自己环境的 CUDA 版本,去官网下载对应版本的 libTorch。例如 CUDA 版本为 11.1,则需要在下载地址中找到 libtorch-cxx11-abi-shared-with-deps-1.9.1%2Bcu111.zip 进行下载。

链接进需要再 cmake 内加上这几行即可:

set(TORCH_PATH "/home/yj/libtorch/share/cmake/Torch")
message("TORCH_PATH set to: ${TORCH_PATH}")
set(Torch_DIR ${TORCH_PATH})find_package(Torch REQUIRED)
message(STATUS "Torch version is: ${Torch_VERSION}")# <target> is your target's name
target_link_libraries(<target> ${TORCH_LIBRARIES}
)

3.2 接口和实现

  1. 头文件引入 :

    #include <torch/script.h>
    #include <torch/torch.h>
    
  2. 加载模型

    module = torch::jit::load(PATH);
    
  3. 函数实现

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/243782.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

了解云工作负载保护:技术和最佳实践

云工作负载是指云环境中的应用程序或存储元素&#xff0c;无论是公共云、私有云还是混合云。每个云工作负载都使用云的资源&#xff0c;包括计算、网络和存储。 云工作负载可以多种多样&#xff0c;例如运行应用程序、数据库或托管网站。它们可以是静态的或动态的&#xff0c;…

mysql 导入数据 1273 - Unknown collation: ‘utf8mb4_0900_ai_ci‘

前言: mysql 导入数据 遇到这个错误 1273 - Unknown collation: utf8mb4_0900_ai_ci 具体原因没有深究 但应该是设计数据库的 字符集类型会出现这个问题 例如: char varchar text..... utf8mb4 类型可以存储表情 在现在这个时代会用很多 以后会用的更多 所以不建议改…

LabVIEW滚动轴承故障在线监测

展示了如何将LabVIEW开发出一种有效的滚动轴承故障在线监测系统。介绍了该系统的开发过程、工作原理及其在实际应用中的效果。该系统成功地应用于对滚动轴承故障的早期诊断&#xff0c;提高了故障检测的准确性和效率。 滚动轴承在工作过程中会产生复杂的振动信号&#xff0c;包…

【C语言进阶】预处理详解

引言 对预处理的相关知识进行详细的介绍 ✨ 猪巴戒&#xff1a;个人主页✨ 所属专栏&#xff1a;《C语言进阶》 &#x1f388;跟着猪巴戒&#xff0c;一起学习C语言&#x1f388; 目录 引言 预定义符号 #define定义常量 #define定义宏 带有副作用的宏参数 宏替换的规则 …

开源的测试平台快2千星了,能带来多少收益呢

最近看了下自己去年初开源的测试平台&#xff0c;star一起算的话也到1.7k了&#xff1a; 做开源的初心一方面是想把自己的理解和思想展示出来&#xff0c;另一方面是想进一步打造个人IP&#xff0c;提升影响力&#xff08;其实这个想法很早之前就有了&#xff0c;计划过无数次但…

【BERT】详解

BERT 简介 BERT 是谷歌在 2018 年时提出的一种基于 Transformer 的双向编码器的表示学习模型&#xff0c;它在多个 NLP 任务上刷新了记录。它利用了大量的无标注文本进行预训练&#xff0c;预训练任务有掩码语言模型和下一句预测&#xff0c;掩码语言模型指的是随机地替换文本中…

「云渲染C4D」C4D如何进行云渲染?

云渲染C4D的过程可现实一键式完成&#xff0c;目前云渲染平台随着技术的发展&#xff0c;平台的使用越发容易操作&#xff0c;无论是渲染文件的传输性、安全性、高效性都有较大的提升&#xff0c;本次为大家简单说明下关于云渲染操作方法。 &#xff08;图源网络&#xff09; …

Android状态栏布局隐藏的方法

1.问题如下&#xff0c;安卓布局很不协调 2.先将ActionBar设置为NoActionBar 先打开styles.xml 3.使用工具类 package com.afison.newfault.utils;import android.annotation.TargetApi; import android.app.Activity; import android.content.Context; import android.graph…

“深入理解 Docker 和 Nacos 的单个部署与集成部署“

目录 引言&#xff1a;Docker Nacos 单个部署1.1 什么是 Docker&#xff1f;Docker 的概念和工作原理Docker 为什么受到广泛应用和认可 1.2 什么是 Nacos&#xff1f;Nacos 的核心功能和特点Nacos 在微服务架构中的作用 1.3 Docker 单个部署 Nacos Docker Nacos 集成部署总结&a…

如何使用固定公网地址访问多个本地Nginx服务搭建的网站

文章目录 1. 下载windows版Nginx2. 配置Nginx3. 测试局域网访问4. cpolar内网穿透5. 测试公网访问6. 配置固定二级子域名7. 测试访问公网固定二级子域名 本文主要介绍如何在Windows系统对Nginx进行配置&#xff0c;并结合cpolar内网穿透工具实现固定公网地址远程访问多个本地站…

Spring Boot整合Redis的高效数据缓存实践

引言 在现代Web应用开发中&#xff0c;数据缓存是提高系统性能和响应速度的关键。Redis作为一种高性能的缓存和数据存储解决方案&#xff0c;被广泛应用于各种场景。本文将研究如何使用Spring Boot整合Redis&#xff0c;通过这个强大的缓存工具提高应用的性能和可伸缩性。 整合…

操作系统导论-课后作业-ch14

1. 代码如下&#xff1a; #include <stdio.h> #include <stdlib.h>int main() {int *i NULL;free(i);return 0; }执行结果如下&#xff1a; 可见&#xff0c;没有任何报错&#xff0c;执行完成。 2. 执行结果如下&#xff1a; 3. valgrind安装使用参考&a…

接口自动化测试(Python+Requests+Unittest)

(1)接口自动化测试的意义、前后端分离思想 接口自动化测试的优缺点&#xff1a; 优点&#xff1a; 1、测试复用性。 2、维护成本相对UI自动化低一些。 为什么UI自动化维护成本更高&#xff1f; 因为前端页面变化太快&#xff0c;而且UI自动化比较耗时&#xff08;比如等待页…

常见PCB封装

表面贴片封装 通孔封装 公众号 | FunIO 微信搜一搜 “funio”&#xff0c;发现更多精彩内容。 个人博客 | blog.boringhex.top

基于神经网络的电力系统的负荷预测

一、背景介绍&#xff1a; 电力系统负荷预测是生产部门的重要工作之一&#xff0c;通过准确的负荷预测&#xff0c;可以经济合理地安排机组的启停、减少旋转备用容量、合理安排检修计划、降低发电成本和提高经济效益。负荷预测按预测的时间可以分为长期、中期和短期负荷预测。…

使用pysimplegui+opencv编写一个摄像头的播放器

需求 使用pysimplegui和opencv实现一个播放器&#xff0c;播放 摄像头的画面。 代码实现 import cv2 import time from typing import Iterable, NamedTuple, Optionalimport PySimpleGUI as sgclass CameraSpec(NamedTuple):name: strindex: intwidth: intheight: intfps: i…

记一次 stackoverflowerror 线上排查过程

一.线上 stackOverFlowError xxx日,突然收到线上日志关键字频繁告警 classCastException.从字面上的报警来看,仅仅是类型转换异常,查看细则发现其实是 stackOverFlowError.很多同学面试的时候总会被问到有没有遇到过线上stackOverFlowError?有么有遇到栈溢出?具体栈溢出怎么来…

网络爬虫采集工具

在当今数字化的时代&#xff0c;获取海量数据对于企业、学术界和个人都至关重要。网络爬虫成为一种强大的工具&#xff0c;能够从互联网上抓取并提取所需的信息。本文将专心分享关于网络爬虫采集数据的全面指南&#xff0c;深入探讨其原理、应用场景以及使用过程中可能遇到的挑…

【论文阅读笔记】Swin-Unet: Unet-like Pure Transformer for Medical Image Segmentation

1.介绍 Swin-Unet: Unet-like Pure Transformer for Medical Image Segmentation Swin-Unet&#xff1a;用于医学图像分割的类Unet纯Transformer 2022年发表在 Computer Vision – ECCV 2022 Workshops Paper Code 2.摘要 在过去的几年里&#xff0c;卷积神经网络&#xff…

OTA 升级软件推荐,附带MD5,CRC16,CRC32,AES算法工具

说明&#xff1a;推荐 OTA 工具软件&#xff0c;可以通过串口按 OTA 协议发送 bin 文件给 MCU,完成 bootloader 升级app 功能 , 这个软件 附带提供 MD5,CRC16,CRC32,AES 算法工具。 文档持续完善中... 1. OTA界面 2.AES.MD5.CRC界面 3.下载链接&#xff1a; 链接: https://p…