首页 > Web开发 > 详细

PSPnet模型结构的实现代码

时间:2019-04-14 11:15:32      阅读:492      评论:0      收藏:0      [点我收藏+]

1
import torch 2 import torch.nn.functional as F 3 from torch import nn 4 from torchvision import models 5 6 from utils import initialize_weights 7 from utils.misc import Conv2dDeformable 8 from .config import res101_path 9 10 //金字塔模块,将从前面卷积结构提取的特征分别进行不同的池化操作,得到不同感受野以及语境信息 11 class _PyramidPoolingModule(nn.Module): 12 def __init__(self, in_dim, reduction_dim, setting): 13 super(_PyramidPoolingModule, self).__init__() 14 self.features = [] 15 for s in setting: //对应不同的池化操作,单个bin,多个bin 16 self.features.append(nn.Sequential( 17 nn.AdaptiveAvgPool2d(s), 18 nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False), 19 nn.BatchNorm2d(reduction_dim, momentum=.95), 20 nn.ReLU(inplace=True) 21 )) 22 self.features = nn.ModuleList(self.features) 23 24 def forward(self, x): 25 x_size = x.size() 26 out = [x] 27 for f in self.features: 28 out.append(F.upsample(f(x), x_size[2:], mode=bilinear)) 29 out = torch.cat(out, 1) 30 return out 31 32 //整个pspnet网络的结构 33 class PSPNet(nn.Module): 34 def __init__(self, num_classes, pretrained=True, use_aux=True): 35 super(PSPNet, self).__init__() 36 self.use_aux = use_aux 37 resnet = models.resnet101() //采用resnet101作为骨干模型,提取特征 38 if pretrained: 39 resnet.load_state_dict(torch.load(res101_path)) 40 self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool) 41 self.layer1, self.layer2, self.layer3, self.layer4 = resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4 42     //设置带洞卷积的参数(dilation),以及卷积的参数 43 for n, m in self.layer3.named_modules(): 44 if conv2 in n: 45 m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1) 46 elif downsample.0 in n: 47 m.stride = (1, 1) 48 for n, m in self.layer4.named_modules(): 49 if conv2 in n: 50 m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1) 51 elif downsample.0 in n: 52 m.stride = (1, 1) 53     //加入ppm模块,以及最后的连接层(卷积) 54 self.ppm = _PyramidPoolingModule(2048, 512, (1, 2, 3, 6)) 55 self.final = nn.Sequential( 56 nn.Conv2d(4096, 512, kernel_size=3, padding=1, bias=False), 57 nn.BatchNorm2d(512, momentum=.95), 58 nn.ReLU(inplace=True), 59 nn.Dropout(0.1), 60 nn.Conv2d(512, num_classes, kernel_size=1) 61 ) 62 63 if use_aux: 64 self.aux_logits = nn.Conv2d(1024, num_classes, kernel_size=1) 65 initialize_weights(self.aux_logits) 66 # 初始化权重 67 initialize_weights(self.ppm, self.final) 68 69 def forward(self, x): 70 x_size = x.size() 71 x = self.layer0(x) 72 x = self.layer1(x) 73 x = self.layer2(x) 74 x = self.layer3(x) 75 if self.training and self.use_aux: 76 aux = self.aux_logits(x) 77 x = self.layer4(x) 78 x = self.ppm(x) 79 x = self.final(x) 80 if self.training and self.use_aux: 81 return F.upsample(x, x_size[2:], mode=bilinear), F.upsample(aux, x_size[2:], mode=bilinear) 82 return F.upsample(x, x_size[2:], mode=bilinear)

 

PSPnet模型结构的实现代码

原文:https://www.cnblogs.com/ywheunji/p/10704237.html

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!