onnx 部署 lstm-crf 类模型的坑#

onnx 没有支持 crf#

这个问题是因为 crf 中 viterbi-decode 步骤用了 for 循环,导致 pytorch Tensor 无法 tracem 从而无法转换为 onnx。 即使有些 crf 的 python 实现在转换过程中没有报错、警告,实际使用时依然是不正常的。

解决方法:

  1. python 层面,将 viterbi-decode 核心部分独立出来,实现为一个 torch.autograd.Function 算子

  2. onnx 层面,使用 c++ 自定义 onnx 算子,相当于用 C++ 实现了一遍 viterbi-decode。

  3. 将 1 中 Function 和 2 的 onnx 算子的名字关联起来。

说会 viterbi decode 算法,它接收 隐层序列(sequence)/转移矩阵(trans_matrix) 为参数,输出最可能的 BIOES 序列。

所以 torch 层面的 Function 实现如下:

class MyViterbiDecode(torch.autograd.Function):
    @staticmethod
    def forward(ctx, sequence, trans_matrix):
        # viterbi_decode 使用一般实现即可
        path = viterbi_decode(sequence, trans_matrix)
        return torch.tensor(path, dtype=torch.int32)

    @staticmethod
    def symbolic(g: torch.Graph, sequence: torch.Value, trans_matrix: torch.Value):
        """Function 算子关联的 C++ 实现的自定义 onnx 算子名称"""
        return g.op("my.domain::ViterbiDecode", seq, mat)

完整的 Function 应该实现 forward/backward, viterbi-decode 只会用在预测流程, 所以不需要实现 backward。

其中, MyViterbiDecode::symbolic 函数是为了 torch.onnx.export 将模型导出为 onnx 时将它转成一个自定义 onnx operation。

为了使用上面的 Function, 原有模型代码改造为:

    class LstmCrfModal(torch.nn.Module):
        ... ...

        def forward(self, ...):
            ... ...
            if not self.training:
-               path = viterbi_decode(sequence, trans_matrix)
+               path = MyViterbiDecode.apply(sequence, trans_matrix)
            ... ...

做到这一步,模型已经可以正确导出了,剩下的,我们需要使用 C++ 实现一遍 ViterbiDecode,就可以把导出的 onnx 放到 onnxruntime 上运行。

C++ 这边我是对照 python 代码实现的,因为不知道原来的算法魔改了什么,一般理想情况下到 github 找个库就好。 另外 python 这边我们也改了很多东西,原来的算法代码实现很不规范,各种 torch 转 numpy 转 python 对象接着花式低效操作, 说多了都是泪。

现在假设已经有了一个 viterbi_decode 的 C++ 实现,下面的代码是封装为 onnxruntime 算子的。

using Ort::Custom::Tensor;
extern void viterbiDecode(float* sequence, float *matrix, int batchSize, int seqLen, int tagNum, int32_t* path /*out*/);

void KernelViterbiDecode(const Tensor<float> &Sequence, const Tensor<float> &TransMatrix, Tensor<int32_t> &Path)
{
    const std::vector<int64_t> &shape = Sequence.Shape();
    const auto batchSize = shape.at(0);
    const auto seqLen = shape.at(1);
    const auto tagNum = shape.at(2);

    const float* seqRaw = Sequence.Data();
    const float* matRaw = TransMatrix.Data();
    std::vector<int64_t> outShape = {batchSize, seqLen};
    int32_t* pathRaw = Path.Allocate(outShape);

    viterbiDecode(seqRaw, matRaw, batchSize, seqLen, tagNum, pathRaw);
}

// 仿照官网样例注册进去
OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options, const OrtApiBase *api)
{
    ... ...

    static const std::unique_prt<LiteOp> viterbiDecode{Ort::Custom::CreateLiteCustomOp("ViterbiDecode", "CPUExecutionProvider", KernelViterbiDecode)};

    Ort::CustomOpDomain domain{"my.domain"}
    domain.Add(viterbiDecode.get());

    // 下面和官网样例代码一样
    ... ...
}

上面代码编译为 custom_op.so 后,可以在 onnxruntime 中使用:

import onnxruntime as ort

sess_opt = ort.SessionOptions()
sess_opt.register_custom_ops_library("/path/to/custom_op.so")
sess = ort.InferenceSession("lstm_crf.onnx", sess_options=sess_opt, providers=...)

result = sess.run(None, { ... })

如果完全按照官网样例的话,只有 onnxruntime 发布包中的头文件是不够的,最好把 git 源码里面 include 目录整个扣下来。 另外 onnxruntime 依赖了微软的 GSL 库,好在这个库是 header-only, 也直接下载下来放到项目目录就好。

相关链接:

  1. pytorch autograd.Function

  2. pytorch onnx.export

  3. onnx custom op

  4. 关于 torch.Graph, 见 pytorch 代码: onnx/utils.py:_graph_op 函数