博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
pytorch1.0 用torch script导出模型
阅读量:7240 次
发布时间:2019-06-29

本文共 1395 字,大约阅读时间需要 4 分钟。

python的易上手和pytorch的动态图特性,使得pytorch在学术研究中越来越受欢迎,但在生产环境,碍于python的GIL等特性,可能达不到高并发、低延迟的要求,存在需要用c++接口的情况。除了将模型导出为ONNX外,pytorch1.0给出了新的解决方案:pytorch 训练模型 - 通过torch script中间脚本保存模型 -- C++加载模型。最近工作需要尝试做了转换,总结一下步骤和遇到的坑。

用torch script把torch模型转成c++接口可读的模型有两种方式:trace && script. trace比script简单,但只适合结构固定的网络模型,即forward中没有控制流的情况,因为trace只会保存运行时实际走的路径。如果forward函数中有控制流,需要用script方式实现。

trace顾名思义,就是沿着数据运算的路径走一遍,官方例子:

 
import torchdef foo(x, y): return 2*x + y traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))

 

 

 

 

script稍复杂,主要改三处:

1. Model由之前继承 nn.Model 改为继承 torch.jit.ScriptModule

2. forward函数前加 @torch.jit.script_method

3. 其他需要调用的函数前加 @torch.jit.script

 

踩过的坑&&解决方法:

A. torch script默认函数或方法的参数都是Tensor类型的,如果不是需要说明,不然调用非Tensor参数时会报类型不符的编译错误。

python3可以直接:

def example_func(param_1: Tensor, param_2: int, param_3: List[int]):

 

 

python2需要用type注释:

def example_func(param_1, param_2, param_3):

#type: (Tensor, int, List[int]) -> Tensor

 

 

 

B. model的方法中forward加@torch.jit.script_method, __init__函数不用

C. 前面说过,torch scrip支持的函数是pytorch的子集,意味着有一部分函数不支持,例如: not boolean,pass, List的切片赋值,CPU和GPU切换的value.to( ), 需要想办法绕过去。看github上讨论区说新版好像已经支持not操作了,没有验证。

 

结论:pytorch 1.0目前的预览版还有比较多优化的空间,至少是在torch script支持的函数集合上,不建议使用,等稳定版发布再看看吧。

  

 

原创内容,转载请注明出处。

 

参考资料:

https://pytorch.org/docs/master/jit.html

https://pytorch.org/tutorials/beginner/deploy_seq2seq_hybrid_frontend_tutorial.html

转载于:https://www.cnblogs.com/Arborday/p/9890999.html

你可能感兴趣的文章
一个完整的大作业
查看>>
Hadoop阅读笔记(一)——强大的MapReduce
查看>>
vue keep-alive保存路由状态1 (接下篇)
查看>>
这是一份极其粗糙的莫比乌斯函数学习笔记
查看>>
我的XHTML学习笔记
查看>>
Jenkins配置自动化构建
查看>>
私有IP
查看>>
Servelt工具类,基于Tomcat8以上版本,提供常见工具方法,包括:cookie查找和删除、文件下载设置、文件上传的表单解析、上传数据和session中数据的比较、多级目录的创建...
查看>>
PHP常见数组排序方法小结
查看>>
vue简单项目实际应用
查看>>
第七次作业
查看>>
主键,外键
查看>>
anguar相关
查看>>
Python 单例模式
查看>>
cocoaPods管理的类换了台电脑,出错了file not found
查看>>
可以打开mdb文件的小软件
查看>>
Windows 8 Metro App开发[4]弹出画面(Flayouts)
查看>>
如何用java读取properties文件
查看>>
hdu1166 (bit)
查看>>
python模块目录文件后续
查看>>