首页 > Web开发 > 详细

使用netron工具可视化pytorch模型

时间:2021-03-11 10:26:22      阅读:48      评论:0      收藏:0      [点我收藏+]

netron是微软小哥lutzroeder的一个广受好评的开源项目,地址https://github.com/lutzroeder/Netro

可惜,默认支持的格式中并不包括pytorch,可能当年小哥面试facebook被拒了,:)

Netron supports ONNX (.onnx.pb.pbtxt), Keras (.h5.keras), Core ML (.mlmodel), Caffe (.caffemodel.prototxt), Caffe2 (predict_net.pbpredict_net.pbtxt), MXNet (.model-symbol.json), NCNN (.param) and TensorFlow Lite (.tflite).

1. 安装netron

pip install netron

2. 测试代码

由于不支持默认的pytorch模型格式(.pth),因此需要存为onnx,庆幸pytorch支持!

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.onnx

import netron


class model(nn.Module):
    def __init__(self):
        super(model, self).__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(64, 64, 3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 32, 1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, 3, padding=1, bias=False),
            nn.BatchNorm2d(64)
        )

        self.conv1 = nn.Conv2d(3, 64, 3, padding=1, bias=False)
        self.output = nn.Sequential(
            nn.Conv2d(64, 1, 3, padding=1, bias=True),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.conv1(x)
        identity = x
        x = F.relu(self.block1(x) + identity)
        x = self.output(x)
        return x


d = torch.rand(1, 3, 416, 416)
m = model()
o = m(d)

onnx_path = "onnx_model_name.onnx"
torch.onnx.export(m, d, onnx_path)

netron.start(onnx_path)

3. 结果

执行上面代码后,会调用本地浏览器打开,形式和tensorboard差不多

Serving ‘onnx_model_name.onnx‘ at http://localhost:8080

技术分享图片

使用netron工具可视化pytorch模型

原文:https://www.cnblogs.com/shuimuqingyang/p/14515579.html

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!