记录一下最近遇到的ONNX动态输入问题
1. 一个tensor的动态输入数据
首先是使用到的onnx的torch.onnx.export()函数:
贴一下官方的代码示意地址:ONNX动态输入
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 | #首先我们要有个tensor输入,比如网络的输入是batch_size*1*224*224 x = torch.randn(batch_size, 1, 224, 224, requires_grad=True) #torch_model是模型的实例化 torch_out = torch_model(x) #下面是导出的主要函数 # Export the model torch.onnx.export(torch_model, # model being run x, # model input (or a tuple for multiple inputs) "super_resolution.onnx", # where to save the model (can be a file or file-like object) export_params=True, # store the trained parameter weights inside the model file opset_version=10, # the ONNX version to export the model to do_constant_folding=True, # whether to execute constant folding for optimization input_names = ['input'], # the model's input names output_names = ['output'], # the model's output names dynamic_axes={'input' : {0 : 'batch_size'}, # variable lenght axes 'output' : {0 : 'batch_size'}}) |
上面我们主要是设置dynamic_axes的相关属性,这个属性的Key是从input_names和output_names里面获取的,所以在这两个里面一定要有相关的属性值,否则会有warning。这里面的batch_size则为动态输入的值,当然我们也可以在外面设置dynamic的属性,比如下面:
1 2 | dynamic_axes = {'input' : {0 : 'batch_size'}, 'output' : {0 : 'batch_size'}} |
然后在外面将dynamic_axes=dynamic_axes 赋值一下就OK了。
2.多个tensor的动态输入问题
那么以上以只有一个tensor输入的情况我们进行的操作。
下面我们说一下有多个动态的tensor输入的情况下如何进行相关的操作:
比如下面的:
1 2 3 4 | pillar_x = torch.ones([1, 1, 9918, 100], dtype=torch.float32, device="cuda:0") pillar_y = torch.ones([1, 1, 9918, 100], dtype=torch.float32, device="cuda:0") pillar_z = torch.ones([1, 1, 9918, 100], dtype=torch.float32, device="cuda:0") pillar_i = torch.ones([1, 1, 9918, 100], dtype=torch.float32, device="cuda:0") |
加入上面的tensor是网络的一个输入,比如9918这个数据是我们网络输入的时候可能发生变化的情况,那么我们就要将其输入成动态输入的模式,详情如下:
1 2 3 4 5 6 7 8 9 10 | input = [pillar_x, pillar_y, pillar_z, pillar_i] dynamic_axes = {'Pillar_input_pillar_x':{2:'pillar_num'}, 'Pillar_input_pillar_y': {2: 'pillar_num'}, 'Pillar_input_pillar_z': {2: 'pillar_num'}, 'Pillar_input_pillar_i': {2: 'pillar_num'}, 'output_loss1':{0:'batch_size'}, 'output_loss2': {0: 'batch_size'}, 'output_loss3': {0: 'batch_size'}} torch.onnx.export(net,input,'test1.onnx',verbose=True,input_names=['Pillar_input_pillar_x','Pillar_input_pillar_y','Pillar_input_pillar_z', 'Pillar_input_pillar_i'], output_names=['output_loss1','output_loss2','output_loss3'],dynamic_axes=dynamic_axes) |
我们单拿一组数据来说:dynamic_axes字典里面的item就是我们要设置动态输入数据,key是我们要动态输入的某一项数据,val的值是这一项数据中的哪一维度要设置成动态的。
那么我们dynamic_axes中的数据是要在input_names和output_names找到的,其实这两项只是数据的别名,然后根据dynamic里面的名字找到我们需要设置的具体动态项,也就是我们可以设置动态的输入,也可以设置动态的输出。
有一点要强调的是input_names中的序列是和input输入的数据相对应的。
实际效果如下:
以上。
目前正在研究怎么使用trt进行动态的数据调用做inference工作,待补充。