简介
ONNX Runtime是一个用于ONNX(Open Neural Network Exchange)模型推理的引擎。微软联合Facebook等在2017年搞了个深度学习以及机器学习模型的格式标准--ONNX,顺路提供了一个专门用于ONNX模型推理的引擎,onnxruntime。目前ONNX Runtime 还只能跑在HOST端,不过官网也表示,对于移动端的适配工作也在进行中。
一半处于工作需要一半出于兴趣,决定阅读一下onnxruntime的源码。这里算个学习记录吧。
安装
ONNX Runtime 的GitHub仓库地址为 https://github.com/microsoft/onnxruntime 。编译安装过程可以参照GitHub上的说明,这里为了方便,直接选择了PyPi的安装源。执行
1 | pip install onnxruntime |
即完成了安装。需要注意的是只支持Python3。
开始
涉及文件
onnxruntimeonnxruntimepythonsession.py
onnxruntimeonnxruntimecoreframeworkutils.cc
onnxruntimeonnxruntimepythononnxruntime_pybind_state.cc
onnxruntimeonnxruntimecoresessioninference_session.cc
onnxruntimeonnxruntimecoresessioninference_session.h
代码入口
代码阅读需要先找到一个入口。通过onnxruntime的例子我们知道,在Python使用使用onnxruntime很简单,主要代码就三行:
1 2 3 | import onnxruntime sess = onnxruntime.InferenceSession('YouModelPath.onnx') output = sess.run([output_nodes], {input_nodes: x}) |
第一行导入onnxruntime模块;第二行创建一个
实例生成
ONNX Runtime的代码组织非常良好,我们很容易找到
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 | def __init__(self, path_or_bytes, sess_options=None, providers=[]): """ :param path_or_bytes: filename or serialized model in a byte string :param sess_options: session options :param providers: providers to use for session. If empty, will use all available providers. """ self._path_or_bytes = path_or_bytes self._sess_options = sess_options self._load_model(providers) self._enable_fallback = True def _load_model(self, providers=[]): if isinstance(self._path_or_bytes, str): self._sess = C.InferenceSession( self._sess_options if self._sess_options else C.get_default_session_options(), self._path_or_bytes, True) elif isinstance(self._path_or_bytes, bytes): self._sess = C.InferenceSession( self._sess_options if self._sess_options else C.get_default_session_options(), self._path_or_bytes, False) # elif isinstance(self._path_or_bytes, tuple): # to remove, hidden trick # self._sess.load_model_no_init(self._path_or_bytes[0], providers) else: raise TypeError("Unable to load from type '{0}'".format(type(self._path_or_bytes))) # 注意看下面这句话,后面我们还会回来详细讲 self._sess.load_model(providers) |
我们发现其实这里
门在这了,开门 的钥匙在哪儿?
我们进盯着Python中
1 2 3 | self._sess = C.InferenceSession( self._sess_options if self._sess_options else C.get_default_session_options(), self._path_or_bytes, True) |
这句话,它是全村的希望。通过这句话,我们知道,在
1 2 3 4 5 6 7 8 9 10 11 12 13 14 | py::class_<InferenceSession>(m, "InferenceSession", R"pbdoc(This is the main class used to run a model.)pbdoc") // In Python3, a Python bytes object will be passed to C++ functions that accept std::string or char* // without any conversion. So this init method can be used for model file path (string) // and model content (bytes) .def(py::init([](const SessionOptions& so, const std::string& arg, bool is_arg_file_name) { // Given arg is the file path. Invoke the corresponding ctor(). if (is_arg_file_name) { return onnxruntime::make_unique<InferenceSession>(so, arg, SessionObjectInitializer::Get()); } // Given arg is the model content as bytes. Invoke the corresponding ctor(). std::istringstream buffer(arg); return onnxruntime::make_unique<InferenceSession>(so, buffer, SessionObjectInitializer::Get()); })) |
欢迎来到C++。
1 2 3 | InferenceSession(const SessionOptions& session_options, const std::string& model_uri, logging::LoggingManager* logging_manager = nullptr); |
的构造函数。
这里有个奇怪的现象:
1 2 3 | if (is_arg_file_name) { return onnxruntime::make_unique<InferenceSession>(so, arg, SessionObjectInitializer::Get()); } |
中第三个参数我们通过查看
来看看
1 2 3 | InferenceSession(const SessionOptions& session_options, const std::string& model_uri, logging::LoggingManager* logging_manager = nullptr); |
的实现:
1 2 3 4 5 6 7 8 9 10 11 12 13 | InferenceSession::InferenceSession(const SessionOptions& session_options, const std::string& model_uri, logging::LoggingManager* logging_manager) : insert_cast_transformer_("CastFloat16Transformer") { model_location_ = ToWideString(model_uri); model_proto_ = onnxruntime::make_unique<ONNX_NAMESPACE::ModelProto>(); auto status = Model::Load(model_location_, *model_proto_); ORT_ENFORCE(status.IsOK(), "Given model could not be parsed while creating inference session. Error message: ", status.ErrorMessage()); // Finalize session options and initialize assets of this session instance ConstructorCommon(session_options, logging_manager); } |
这里主要就做了三件事:
- 将模型地址保存在类成员变量
model_location_ 中; - 将模型二进制内容保存在类成员变量
model_proto_ ; - 调用
ConstructorCommon 完成剩余的工作。
ConstructorCommon 中做些环境检查,准备log输出等工作。其中最主要的是,是创建了一个SessionState 实例session_state_ ,这是类成员变量,其中打包了为运行这个模型所需要的线程池、模型结构、provider等信息。至于什么是Provider,其实就是模型所跑的硬件,比如是CPU还是GPU,到了这里其实session_state_ 里面很多信息还没完备,例如模型结构并未保存,Provider还只是个壳,里面并没有保存任何硬件信息,还需要一个初始化阶段。至此,InferenceSession 实例创建完毕。
初始化
又回到最初的起点,Python代码开始的地方,最后一句
1 2 3 4 5 6 | .def( "load_model", [](InferenceSession* sess, std::vector<std::string>& provider_types) { OrtPybindThrowIfError(sess->Load()); InitializeSession(sess, provider_types); }, R"pbdoc(Load a model saved in ONNX format.)pbdoc") |
- 将模型二进制内容解析;
- 选择模型运行方式,并行还是串行;
- 选择模型Provider,如果用户没有指定Provider,就把目前运行环境中支持的硬件都注册,比如GPU,CPU等,并且保证CPU一定可用;
- 确定模型中各个节点的运行先后顺序。
这里先不细说了,只需要知道它是按照ONNX标准将二进制数据解析成一个图并将它存储在session_stat_ 中就可以了。以后再详细说。经过这一步之后,session_state_ 已经完备,到达神装,可以随时开战。
运行
经过初始化之后,一切就绪。我们直接看C++中
1 2 3 | Status InferenceSession::Run(const RunOptions& run_options, const std::vector<std::string>& feed_names, const std::vector<OrtValue>& feeds, const std::vector<std::string>& output_names, std::vector<OrtValue>* p_fetches) |
的这个。在这里,
1 2 3 | utils::ExecuteGraph(*session_state_, feeds_fetches_manager, feeds, *p_fetches, session_options_.execution_mode, run_options.terminate, run_logger)) |
而
总结
一个大概流程就是通过使用pybind11将C++接口暴露给Python,Python经过简单封装后提供给用户直接使用。上面有几个关键点值得深入研究:
- 模型节点执行顺序的确定;
- 模型节点Provider的选择;
- 模型解析过程;
- 模型推理详细过程;
- 模型如何高效推理。
最后,一图胜千言:onnxruntime_exec_flow.png