onnx 部署 lstm-crf 类模型的坑#
onnx 没有支持 crf#
这个问题是因为 crf 中 viterbi-decode 步骤用了 for 循环,导致 pytorch Tensor 无法 tracem 从而无法转换为 onnx。 即使有些 crf 的 python 实现在转换过程中没有报错、警告,实际使用时依然是不正常的。
解决方法:
python 层面,将 viterbi-decode 核心部分独立出来,实现为一个 torch.autograd.Function 算子
onnx 层面,使用 c++ 自定义 onnx 算子,相当于用 C++ 实现了一遍 viterbi-decode。
将 1 中 Function 和 2 的 onnx 算子的名字关联起来。
说会 viterbi decode 算法,它接收 隐层序列(sequence)/转移矩阵(trans_matrix) 为参数,输出最可能的 BIOES 序列。
所以 torch 层面的 Function 实现如下:
class MyViterbiDecode(torch.autograd.Function):
@staticmethod
def forward(ctx, sequence, trans_matrix):
# viterbi_decode 使用一般实现即可
path = viterbi_decode(sequence, trans_matrix)
return torch.tensor(path, dtype=torch.int32)
@staticmethod
def symbolic(g: torch.Graph, sequence: torch.Value, trans_matrix: torch.Value):
"""Function 算子关联的 C++ 实现的自定义 onnx 算子名称"""
return g.op("my.domain::ViterbiDecode", seq, mat)
完整的 Function 应该实现 forward/backward, viterbi-decode 只会用在预测流程, 所以不需要实现 backward。
其中, MyViterbiDecode::symbolic 函数是为了 torch.onnx.export 将模型导出为 onnx 时将它转成一个自定义 onnx operation。
为了使用上面的 Function, 原有模型代码改造为:
class LstmCrfModal(torch.nn.Module):
... ...
def forward(self, ...):
... ...
if not self.training:
- path = viterbi_decode(sequence, trans_matrix)
+ path = MyViterbiDecode.apply(sequence, trans_matrix)
... ...
做到这一步,模型已经可以正确导出了,剩下的,我们需要使用 C++ 实现一遍 ViterbiDecode,就可以把导出的 onnx 放到 onnxruntime 上运行。
C++ 这边我是对照 python 代码实现的,因为不知道原来的算法魔改了什么,一般理想情况下到 github 找个库就好。 另外 python 这边我们也改了很多东西,原来的算法代码实现很不规范,各种 torch 转 numpy 转 python 对象接着花式低效操作, 说多了都是泪。
现在假设已经有了一个 viterbi_decode 的 C++ 实现,下面的代码是封装为 onnxruntime 算子的。
using Ort::Custom::Tensor;
extern void viterbiDecode(float* sequence, float *matrix, int batchSize, int seqLen, int tagNum, int32_t* path /*out*/);
void KernelViterbiDecode(const Tensor<float> &Sequence, const Tensor<float> &TransMatrix, Tensor<int32_t> &Path)
{
const std::vector<int64_t> &shape = Sequence.Shape();
const auto batchSize = shape.at(0);
const auto seqLen = shape.at(1);
const auto tagNum = shape.at(2);
const float* seqRaw = Sequence.Data();
const float* matRaw = TransMatrix.Data();
std::vector<int64_t> outShape = {batchSize, seqLen};
int32_t* pathRaw = Path.Allocate(outShape);
viterbiDecode(seqRaw, matRaw, batchSize, seqLen, tagNum, pathRaw);
}
// 仿照官网样例注册进去
OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options, const OrtApiBase *api)
{
... ...
static const std::unique_prt<LiteOp> viterbiDecode{Ort::Custom::CreateLiteCustomOp("ViterbiDecode", "CPUExecutionProvider", KernelViterbiDecode)};
Ort::CustomOpDomain domain{"my.domain"}
domain.Add(viterbiDecode.get());
// 下面和官网样例代码一样
... ...
}
上面代码编译为 custom_op.so 后,可以在 onnxruntime 中使用:
import onnxruntime as ort
sess_opt = ort.SessionOptions()
sess_opt.register_custom_ops_library("/path/to/custom_op.so")
sess = ort.InferenceSession("lstm_crf.onnx", sess_options=sess_opt, providers=...)
result = sess.run(None, { ... })
如果完全按照官网样例的话,只有 onnxruntime 发布包中的头文件是不够的,最好把 git 源码里面 include 目录整个扣下来。 另外 onnxruntime 依赖了微软的 GSL 库,好在这个库是 header-only, 也直接下载下来放到项目目录就好。
相关链接:
关于 torch.Graph, 见 pytorch 代码: onnx/utils.py:
_graph_op
函数