首页 > 其他 > 详细

Datawhale 学CV--task3 代码理解

时间:2020-05-26 12:47:08      阅读:40      评论:0      收藏:0      [点我收藏+]

尝试修改网络结构的记录:

1、resnet18修改为vgg16,epoch=2时效果差点,修改:

model_conv = models.vgg16(pretrained=True)

#model_conv = models.restnet18(pretrained=True)

2、如果每个stage结构都一样,可以写如下,再传参数。

self.conv1 = nn.Sequential(
nn.Conv2d()
nn.BatchNormal()
nn.PReLU()
nn.MaxPooling()
nn.Dropout()
)

传参数:conv2d(n_in,n_out,kernel,stride,padding)

batchnormal(n_out)待处理数据的channel;batchnormal(n_out,0.1)包含了?

PRelu可以换成Relu,无参数;

Maxpooling(2)表示(2,2)的maxpooling

Dropout(0.2),网络小时取值也小,一般0.2--0.5

3. 代码记录如下;

def __init__(self):
super(SVHN_Model1, self).__init__()
# model_conv = models.vgg16(pretrained=True)

self.fc1 = nn.Linear(100, 11)
self.fc2 = nn.Linear(100, 11)
self.fc3 = nn.Linear(100, 11)
self.fc4 = nn.Linear(100, 11)
self.fc5 = nn.Linear(100, 11)

def conv_bn(inp, oup):
return nn.Sequential(
nn.Conv2d(inp, oup, 5, 1, 2),
nn.BatchNorm2d(oup),
nn.ReLU(inplace=True),
nn.MaxPool2d(2,2,1),
nn.Dropout(0.2),
)

self.cnn = nn.Sequential(
conv_bn( 3, 48),
conv_bn( 48, 64),
conv_bn( 64, 128),
conv_bn( 128,160),
conv_bn( 160, 192),
conv_bn( 192, 192),
conv_bn( 192, 192),
conv_bn( 192, 192), )

def dense_1(inp, oup):
return nn.Sequential(
nn.Linear(inp,oup),
nn.ReLU(inplace=True),
)
self.fc_2 = nn.Sequential(

dense_1( 768, 512),
dense_1( 512, 100),)

def forward(self, img):
feat = self.cnn(img)
#print(feat.shape)
#feat = feat.view(feat.shape[0], -1)
feat = feat.view(feat.shape[0], -1)
feat = self.fc_2(feat)

c1 = self.fc1(feat)
c2 = self.fc2(feat)
c3 = self.fc3(feat)
c4 = self.fc4(feat)
c5 = self.fc5(feat)
return c1, c2, c3, c4, c5

Datawhale 学CV--task3 代码理解

原文:https://www.cnblogs.com/haiyanli/p/12964599.html

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!