pytorch 模型的保存与加载方法以及使用onnx模型部署推理

前言

在pytorch中,模型的保存和加载是一个比较麻烦的点,因为Pytorch保存模型的pkl格式中仅能记住每一层对应的参数,以及对应的模型结构所保存的位置也即是class所存放的位置,假如class位置改变的话,那么模型加载的时候就会报很多错误,大概的意思就是类存放的位置发生了改变。那么如何改变这一现状呢?
首先对于模型的重新加载和训练,我们可以使用pkl文件格式来进行,因为这一阶段的话并不涉及到模型结构存储文件的转移。但是到了推理部署阶段的时候,我们就必须把模型的结构以及权重保存下来,否则会给我们后面的工作带来很大的麻烦,而将pytorch的模型转为onnx模型进行推理部署就是个非常不错的方法。
核心的思想就是将pytorch的模型进行一次前向传播,然后通过输出的张量的梯度传播方式进行静态图模式的数据流传播方式的推导,然后构建静态图模型,构建完毕后将包含数据的整个静态图进行存储。

pkl模型保存加载方法

只保存参数

1
2
3
4
5
6
7
8
#-----------------------------------------保存
torch.save(net.cpu().state_dict(),path+'/xxxx.pkl')
#-----------------------------------------加载      
model_state_dict = torch.load(path+'/xxxx.pkl')
net = Model()
#strict表示不严格将参数进行对应
#如果为True那么模型中各层的名称发生变化的时候会无法载入模型权重
net.load_state_dict(model_state_dict , strict=False)

只保存checkpoint的各项参数,这个是比较推荐的方法,这个方法可以督促你记住在python中采用pickle库所保存的class必须有原有的类才能将其构建数据的方式进行还原

1
2
3
4
5
6
7
8
9
10
11
12
#-----------------------------------------保存
torch.save({'model_state_dict':net.cpu().state_dict(),
            'model_Sequential':list(net.children()),#只对序贯模型有用
            'optimizer_state_dict': optimizer.state_dict()},
            path+'/xxxx.pkl')
#-----------------------------------------加载      
checkpoint = torch.load(path+'/xxxx.pkl')
#如果是序贯模型的话可以按照下列方法得到模型的结构
net = nn.Sequential(*checkpoint['model_Sequential'])
#net = Model() 大部分情况用这个加载原有模型
net.load_state_dict(checkpoint['model_state_dict'] , strict=False)
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

保存参数以及class存放的位置

1
2
3
4
#-----------------------------------------保存
torch.save(net.cpu(),path+'/xxxx.pkl')
#-----------------------------------------加载  
net = torch.load(path+'/xxxx.pkl')

ONNX

https://www.jianshu.com/p/65cfb475584a

Open Neural Network Exchange(简称 ONNX,开放神经网络交换)格式,是一个用于表示深度学习模型的标准,可使模型在不同框架之间进行转移。
ONNX 是一种针对机器学习所设计的开放式的文件格式,用于存储训练好的模型。它使得不同的人工智能框架(如 Pytorch, MXNet)可以采用相同格式存储模型数据并交互。ONNX 的规范及代码主要由微软,亚马逊 ,Facebook 和 IBM 等公司共同开发,以开放源代码的方式托管在 Github 上。目前官方支持加载 ONNX 模型并进行推理的深度学习框架有:Caffe2, PyTorch, MXNet,ML.NET,TensorRT 和 Microsoft CNTK,并且 TensorFlow 也非官方的支持ONNX。
ONNX Runtime 是针对 ONNX 模型的以性能为中心的引擎,可在多个平台和硬件 (Windows,Linux 和 Mac 以及 CPU 和 GPU 上) 高效地进行推理。ONNX Runtime 可大大提高多个模型的性能。

准备

首先得安装onnx的库

1
2
3
pip install onnx
pip install onnxruntime
pip install onnxruntime-gpu

onnx模型的保存、加载与推理

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import numpy as np
import torch
import onnx
import onnxruntime

#-------------------------------------保存模型
#随机生成一个符合模型的输入张量
x = torch.randn(1,100,13, requires_grad=True)
net.eval()
torch_out = net(x)
torch.onnx.export(net,
                  x,
                  "xxxx.onnx",
                  export_params=True,
                  opset_version=11,
                  do_constant_folding=True,
                  input_names = ['input'],
                  output_names = ['output'],
                  #dynamic_axes 表示的是动态输入的维度,数字表示序号 字符串表示名称
                  dynamic_axes={'input' : {0 : 'batch_size',1:'channel'},
                                'output' : {0 : 'batch_size',1:'channel'}})
#-------------------------------------模型的加载与推理
# 使用 ONNX 的 API 检查 ONNX 模型
onnx_model = onnx.load(xxxx.onnx")
onnx.checker.check_model(onnx_model)
# 使用 ONNX Runtime 运行模型
ort_session = onnxruntime.InferenceSession("xxxx.onnx")
def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
 #构建输入并得到输出
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
ort_outs = ort_session.run(None, ort_inputs)
ort_out = ort_outs[0]
# --------------------------------------比较ONNX Runtime 和 PyTorch 的结果
np.testing.assert_allclose(to_numpy(torch_out), ort_out, rtol=1e-03, atol=1e-05)
print("Exported model has been tested with ONNXRuntime, and the result looks good!")

总结

本篇文章主要总结了Pytorch模型的几种保存方法并且在此基础上讲解了onnx模型的推理部署方法,具体步骤在程序中有较为详细的注释,若大家还有其他疑问的可以在评论区里留言