ONNX
ONNX(Open Neural Network Exchange)是一个开放的格式,用于表示机器学习模型。它使得不同框架之间的模型可以互操作,方便模型的迁移和部署。以下是一些关于 ONNX 的基本介绍和使用方法。
- 模型转换:ONNX 允许你将模型从一个深度学习框架(如 PyTorch、TensorFlow)转换为 ONNX 格式。
- 互操作性:ONNX 模型可以在支持 ONNX 的不同平台和工具之间共享。
- 优化:ONNX 提供了工具来优化模型,以提高推理性能。
将模型转换为 ONNX 格式
以下是将 PyTorch 模型转换为 ONNX 模型的步骤:
- 安装 ONNX
安装了 ONNX 和相关的转换工具:
pip install onnx
pip install onnxruntime # 用于运行 ONNX 模型
pip install torch # PyTorch
- 转换 PyTorch 模型
一个已训练的 PyTorch 模型,可以使用以下代码将其转换为 ONNX 格式:
import torch
import torch.onnx
import torchvision.models as models# 加载预训练的 PyTorch 模型
model = models.resnet18(pretrained=True)
model.eval() # 设置模型为推理模式# 创建示例输入张量
dummy_input = torch.randn(1, 3, 224, 224)# 将模型导出为 ONNX 格式
torch.onnx.export(model, dummy_input, "resnet18.onnx", verbose=True)
在这个示例中,将一个预训练的 ResNet-18 模型转换为 ONNX 格式并保存为 resnet18.onnx
文件。
加载和运行 ONNX 模型
使用 ONNX Runtime 来加载和运行转换后的 ONNX 模型:
import onnx
import onnxruntime as ort
import numpy as np# 加载 ONNX 模型
onnx_model = onnx.load("resnet18.onnx")
onnx.checker.check_model(onnx_model) # 检查模型是否有效# 创建 ONNX Runtime 会话
ort_session = ort.InferenceSession("resnet18.onnx")# 创建输入数据
dummy_input = np.random.randn(1, 3, 224, 224).astype(np.float32)# 运行模型
outputs = ort_session.run(None, {"input": dummy_input})
print(outputs[0])
检查和优化 ONNX 模型
ONNX 提供了一些工具来检查和优化模型:
1. 检查模型
使用 onnx.checker
来验证模型的有效性:
import onnxonnx_model = onnx.load("resnet18.onnx")
onnx.checker.check_model(onnx_model)
2. 优化模型
使用 onnx.optimizer
来优化模型:
import onnx
import onnx.optimizeronnx_model = onnx.load("resnet18.onnx")# 定义优化通道
passes = ["fuse_consecutive_transposes", "eliminate_deadend"]# 优化模型
optimized_model = onnx.optimizer.optimize(onnx_model, passes)# 保存优化后的模型
onnx.save(optimized_model, "resnet18_optimized.onnx")
其他常用工具和库
- Netron:用于可视化 ONNX 模型的工具。可以下载并使用 Netron 打开 .onnx 文件进行模型可视化。
- ONNX Model Zoo:ONNX 模型库,包含许多预训练的 ONNX 模型,可以直接下载和使用。
小结
ONNX 作为一个开放的模型格式,可以极大地提高模型在不同框架和平台之间的可移植性。通过学习如何将模型转换为 ONNX 格式,并使用 ONNX Runtime 进行推理和优化,你可以更高效地部署和管理你的机器学习模型。
只有一个元素的时候才能够使用item()
转为scalar,无论是一个0维度张量,还是1维张量,还是2维度
x_t = torch.tensor([1.0])
x2_t =torch.tensor(1.0)
x4_t = torch.tensor([[[1.0]]])x_n = x_t.item() # 1.0
x2_n = x2_t.item() # 1.0
x3_n = x3_t.item() # 1.0