ONNX Runtime 源码阅读:模型推理过程概览

简介

ONNX Runtime是一个用于ONNX(Open Neural Network Exchange)模型推理的引擎。微软联合Facebook等在2017年搞了个深度学习以及机器学习模型的格式标准--ONNX,顺路提供了一个专门用于ONNX模型推理的引擎,onnxruntime。目前ONNX Runtime 还只能跑在HOST端,不过官网也表示,对于移动端的适配工作也在进行中。
一半处于工作需要一半出于兴趣,决定阅读一下onnxruntime的源码。这里算个学习记录吧。

安装

ONNX Runtime 的GitHub仓库地址为 https://github.com/microsoft/onnxruntime 。编译安装过程可以参照GitHub上的说明,这里为了方便,直接选择了PyPi的安装源。执行

1
pip install onnxruntime

即完成了安装。需要注意的是只支持Python3。

开始

涉及文件

onnxruntimeonnxruntimepythonsession.py
onnxruntimeonnxruntimecoreframeworkutils.cc
onnxruntimeonnxruntimepythononnxruntime_pybind_state.cc
onnxruntimeonnxruntimecoresessioninference_session.cc
onnxruntimeonnxruntimecoresessioninference_session.h

代码入口

代码阅读需要先找到一个入口。通过onnxruntime的例子我们知道,在Python使用使用onnxruntime很简单,主要代码就三行:

1
2
3
import onnxruntime
sess = onnxruntime.InferenceSession('YouModelPath.onnx')
output = sess.run([output_nodes], {input_nodes: x})

第一行导入onnxruntime模块;第二行创建一个InferenceSession的实例并传给它一个模型地址;第三行调用run方法进行模型推理。因此onnxruntime模块中的InferenceSession就是我们的切入点。

实例生成

ONNX Runtime的代码组织非常良好,我们很容易找到InferenceSession所在文件session.py,整个文件非常简单,就只定义了一个InferenceSession类。通过阅读InferenceSession__init__函数,

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
    def __init__(self, path_or_bytes, sess_options=None, providers=[]):
        """
        :param path_or_bytes: filename or serialized model in a byte string
        :param sess_options: session options
        :param providers: providers to use for session. If empty, will use
            all available providers.
        """
        self._path_or_bytes = path_or_bytes
        self._sess_options = sess_options
        self._load_model(providers)
        self._enable_fallback = True

    def _load_model(self, providers=[]):
        if isinstance(self._path_or_bytes, str):
            self._sess = C.InferenceSession(
                self._sess_options if self._sess_options else C.get_default_session_options(),
                self._path_or_bytes, True)
        elif isinstance(self._path_or_bytes, bytes):
            self._sess = C.InferenceSession(
                self._sess_options if self._sess_options else C.get_default_session_options(),
                self._path_or_bytes, False)
        # elif isinstance(self._path_or_bytes, tuple):
            # to remove, hidden trick
        #   self._sess.load_model_no_init(self._path_or_bytes[0], providers)
        else:
            raise TypeError("Unable to load from type '{0}'".format(type(self._path_or_bytes)))
        # 注意看下面这句话,后面我们还会回来详细讲
        self._sess.load_model(providers)

我们发现其实这里InferenceSession只不过是一个壳,所有工作都委托给了C.InferenceSessionC从导入语句from onnxruntime.capi import _pybind_state as C可知其实就是一个C语言实现的Python接口,其源码在onnxruntimeonnxruntimepythononnxruntime_pybind_state.cc中。onnxruntime_pybind_state.cc是将C++代码暴露给Python的一个接口,就像是一个门,代码经过这里,就从Python进入了C++的世界。
门在这了,开门 的钥匙在哪儿?
我们进盯着Python中

1
2
3
 self._sess = C.InferenceSession(
                self._sess_options if self._sess_options else C.get_default_session_options(),
                self._path_or_bytes, True)

这句话,它是全村的希望。通过这句话,我们知道,在onnxruntime_pybind_state.cc应该会定义有一个类,名叫InferenceSession,一顿操作猛如虎,定位到InferenceSession定义的地方:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
py::class_<InferenceSession>(m, "InferenceSession", R"pbdoc(This is the main class used to run a model.)pbdoc")
      // In Python3, a Python bytes object will be passed to C++ functions that accept std::string or char*
      // without any conversion. So this init method can be used for model file path (string)
      // and model content (bytes)
      .def(py::init([](const SessionOptions& so, const std::string& arg, bool is_arg_file_name) {
        // Given arg is the file path. Invoke the corresponding ctor().
        if (is_arg_file_name) {
          return onnxruntime::make_unique<InferenceSession>(so, arg, SessionObjectInitializer::Get());
        }

        // Given arg is the model content as bytes. Invoke the corresponding ctor().
        std::istringstream buffer(arg);
        return onnxruntime::make_unique<InferenceSession>(so, buffer, SessionObjectInitializer::Get());
      }))

欢迎来到C++。def(py::init([](const SessionOptions& so, const std::string& arg, bool is_arg_file_name)实现了类似Python中__init__的功能,其根据传入的模型参数类型(模型的地址还是模型的数据流),调用C++中的类InferenceSession的相应构造函数构造一个的实例,然后将这个实例的指针返回给Python。由于我们例子中传入的是模型的地址字符串,因此我们需要找到的是签名类型为:

1
2
3
InferenceSession(const SessionOptions& session_options,
                   const std::string& model_uri,
                   logging::LoggingManager* logging_manager = nullptr);

的构造函数。
这里有个奇怪的现象:

1
2
3
  if (is_arg_file_name) {
          return onnxruntime::make_unique<InferenceSession>(so, arg, SessionObjectInitializer::Get());
        }

中第三个参数我们通过查看SessionObjectInitializer::Get()获取到的是类SessionObjectInitializer的一个实例,但是InferenceSession对应的构造函数对应为所需要的是一个logging::LoggingManager的指针,对不上,咋整?我们知道C++可不像Python,C++是强类型的语言,不将就。这里作者用了个小技巧,他为SessionObjectInitializer定义了两个类型转换函数,让编译器帮他转到所需要的类型,这里编译器会将SessionObjectInitializer转换成logging::LoggingManager*
来看看

1
2
3
InferenceSession(const SessionOptions& session_options,
                   const std::string& model_uri,
                   logging::LoggingManager* logging_manager = nullptr);

的实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
InferenceSession::InferenceSession(const SessionOptions& session_options,
                                   const std::string& model_uri,
                                   logging::LoggingManager* logging_manager)
    : insert_cast_transformer_("CastFloat16Transformer") {
  model_location_ = ToWideString(model_uri);
  model_proto_ = onnxruntime::make_unique<ONNX_NAMESPACE::ModelProto>();
  auto status = Model::Load(model_location_, *model_proto_);
  ORT_ENFORCE(status.IsOK(), "Given model could not be parsed while creating inference session. Error message: ",
              status.ErrorMessage());

  // Finalize session options and initialize assets of this session instance
  ConstructorCommon(session_options, logging_manager);
}

这里主要就做了三件事:

  1. 将模型地址保存在类成员变量model_location_中;
  2. 将模型二进制内容保存在类成员变量model_proto_;
  3. 调用ConstructorCommon完成剩余的工作。
    ConstructorCommon中做些环境检查,准备log输出等工作。其中最主要的是,是创建了一个SessionState实例session_state_,这是类成员变量,其中打包了为运行这个模型所需要的线程池、模型结构、provider等信息。至于什么是Provider,其实就是模型所跑的硬件,比如是CPU还是GPU,到了这里其实session_state_里面很多信息还没完备,例如模型结构并未保存,Provider还只是个壳,里面并没有保存任何硬件信息,还需要一个初始化阶段。至此,InferenceSession实例创建完毕。

初始化

又回到最初的起点,Python代码开始的地方,最后一句self._sess.load_model(providers),其实现如下:

1
2
3
4
5
6
.def(
          "load_model", [](InferenceSession* sess, std::vector<std::string>& provider_types) {
            OrtPybindThrowIfError(sess->Load());
            InitializeSession(sess, provider_types);
          },
          R"pbdoc(Load a model saved in ONNX format.)pbdoc")

load_model主要做了一下事情:

  1. 将模型二进制内容解析;
  2. 选择模型运行方式,并行还是串行;
  3. 选择模型Provider,如果用户没有指定Provider,就把目前运行环境中支持的硬件都注册,比如GPU,CPU等,并且保证CPU一定可用;
  4. 确定模型中各个节点的运行先后顺序。
    这里先不细说了,只需要知道它是按照ONNX标准将二进制数据解析成一个图并将它存储在session_stat_中就可以了。以后再详细说。经过这一步之后,session_state_已经完备,到达神装,可以随时开战。

运行

经过初始化之后,一切就绪。我们直接看C++中InferenceSessionrun方法好了,因为通过前面知道,在Python中的操作最终都会调用到C++的代码来执行实际的内容。虽然InferenceSession重载了很多run方法,但是最终都会辗转调用到签名为

1
2
3
Status InferenceSession::Run(const RunOptions& run_options, const std::vector<std::string>& feed_names,
                             const std::vector<OrtValue>& feeds, const std::vector<std::string>& output_names,
                             std::vector<OrtValue>* p_fetches)

的这个。在这里,run方法对输入数据做了些检查等工作后,变将数据、模型信息,provider信息等,传递给了utils::ExecuteGraph:

1
2
3
utils::ExecuteGraph(*session_state_, feeds_fetches_manager, feeds, *p_fetches,
                            session_options_.execution_mode,
                            run_options.terminate, run_logger))

utils::ExecuteGraph反手又将工作委托给了utils::ExecuteGraphImpl,而utils::ExecuteGraphImpl将会根据前面初始化中确定的各个node的执行先后顺序,找到node类似对对应的kernel,调用他们Compute()方法进行计算。

总结

一个大概流程就是通过使用pybind11将C++接口暴露给Python,Python经过简单封装后提供给用户直接使用。上面有几个关键点值得深入研究:

  1. 模型节点执行顺序的确定;
  2. 模型节点Provider的选择;
  3. 模型解析过程;
  4. 模型推理详细过程;
  5. 模型如何高效推理。
    最后,一图胜千言:

    onnxruntime_exec_flow.png