首页 > Web开发 > 详细

TensorFlow 2.0 —— ResNet 网络设计代码

时间:2020-10-01 23:46:51      阅读:117      评论:0      收藏:0      [点我收藏+]
import tensorflow as tf


class BasicBlock(tf.keras.layers.Layer):
    """Two-convolution residual block (as used by ResNet-18/34).

    Main branch: 3x3 conv (with ``stride``) -> BN -> ReLU -> 3x3 conv -> BN.
    The main branch is summed with a shortcut branch and passed through a
    final ReLU.

    Args:
        filter_num: Number of output channels of both 3x3 convolutions.
        stride: Stride of the first 3x3 convolution. When it is not 1, the
            shortcut is a strided 1x1 convolution with ``filter_num`` output
            channels so that both branches have the same shape. When it is 1,
            the shortcut is the identity, so the input is expected to already
            have ``filter_num`` channels.
    """

    def __init__(self, filter_num, stride=1):
        super(BasicBlock, self).__init__()

        # Keras spells the keyword `strides` (not `stride`), and the padding
        # mode / activation are given as strings.
        self.conv1 = tf.keras.layers.Conv2D(
            filter_num, (3, 3), strides=stride, padding='same')
        self.bn1 = tf.keras.layers.BatchNormalization()
        self.relu = tf.keras.layers.Activation('relu')

        self.conv2 = tf.keras.layers.Conv2D(
            filter_num, (3, 3), strides=1, padding='same')
        self.bn2 = tf.keras.layers.BatchNormalization()

        if stride != 1:
            # Projection shortcut: match the main branch's reduced spatial
            # size and its channel count so the two can be added.
            self.downsample = tf.keras.Sequential()
            self.downsample.add(
                tf.keras.layers.Conv2D(filter_num, (1, 1), strides=stride))
        else:
            self.downsample = lambda x: x

    def call(self, inputs, training=None):
        """Apply the residual block.

        Args:
            inputs: Input feature map.
            training: Forwarded to the BatchNormalization layers so they use
                batch statistics in training and moving averages otherwise.

        Returns:
            ReLU(main_branch(inputs) + shortcut(inputs)).
        """
        out = self.conv1(inputs)
        out = self.bn1(out, training=training)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out, training=training)

        identity = self.downsample(inputs)
        # `tf.keras.layers` is a module and cannot be called; the residual
        # connection is an element-wise sum of the two branches.
        output = tf.add(out, identity)
        output = tf.nn.relu(output)

        return output


class ResNet(tf.keras.Model):
    """ResNet-style image classifier built from four stages of BasicBlocks.

    Subclasses ``tf.keras.Model`` (itself a ``Layer`` subclass) so the
    network can be compiled, trained and summarised directly while still
    being usable wherever a layer is expected.

    Args:
        layer_dims: Four ints giving the number of BasicBlocks in each stage,
            e.g. ``[2, 2, 2, 2]`` means 4 stages with 2 BasicBlocks each.
        num_classes: Number of output units of the final Dense layer
            (raw logits, no activation).
    """

    def __init__(self, layer_dims, num_classes=100):
        super(ResNet, self).__init__()
        # Stem: 3x3 conv -> BN -> ReLU -> 2x2/2 max-pool.
        self.stem = tf.keras.Sequential([
            tf.keras.layers.Conv2D(64, (3, 3), strides=(1, 1)),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Activation('relu'),
            tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2), padding='same'),
        ])

        # Four stages; stages 2-4 start with a stride-2 block and double the
        # channel count.
        self.layer1 = self.build_resblock(64, layer_dims[0])
        self.layer2 = self.build_resblock(128, layer_dims[1], stride=2)
        self.layer3 = self.build_resblock(256, layer_dims[2], stride=2)
        self.layer4 = self.build_resblock(512, layer_dims[3], stride=2)
        # Average over the spatial dimensions. With the default channels_last
        # layout this maps [b, h, w, 512] -> [b, 512].
        self.avgpool = tf.keras.layers.GlobalAveragePooling2D()
        self.fc = tf.keras.layers.Dense(num_classes)

    def call(self, inputs, training=None):
        """Run the network and return logits of shape [b, num_classes].

        ``training`` is forwarded to every stage so BatchNormalization layers
        switch correctly between training and inference behaviour.
        """
        x = self.stem(inputs, training=training)

        x = self.layer1(x, training=training)
        x = self.layer2(x, training=training)
        x = self.layer3(x, training=training)
        x = self.layer4(x, training=training)

        # [b, c]
        x = self.avgpool(x)
        # [b, num_classes]
        x = self.fc(x)
        return x

    def build_resblock(self, filter_num, blocks, stride=1):
        """Build one stage of ``blocks`` BasicBlocks.

        Only the first block uses ``stride`` (and therefore a projection
        shortcut when ``stride != 1``). The remaining blocks use stride 1.
        At least one block is always created.

        Args:
            filter_num: Channel count of every block in the stage.
            blocks: Number of BasicBlocks in the stage.
            stride: Stride of the first block.

        Returns:
            A ``tf.keras.Sequential`` holding the stage's blocks.
        """
        res_blocks = tf.keras.Sequential()
        res_blocks.add(BasicBlock(filter_num, stride))

        # The first block was added above, so add `blocks - 1` more. The
        # previous `range(blocks)` built `blocks + 1` blocks per stage.
        for _ in range(1, blocks):
            res_blocks.add(BasicBlock(filter_num, stride=1))

        return res_blocks

def resnet18(num_classes=100):
    """Build a ResNet with ``layer_dims=[2, 2, 2, 2]`` (ResNet-18 layout).

    Args:
        num_classes: Number of output logits; defaults to 100 as before.
    """
    return ResNet([2, 2, 2, 2], num_classes=num_classes)


def resnet34(num_classes=100):
    """Build a ResNet with ``layer_dims=[3, 4, 6, 3]`` (ResNet-34 layout).

    Args:
        num_classes: Number of output logits; defaults to 100 as before.
    """
    return ResNet([3, 4, 6, 3], num_classes=num_classes)

 

TensorFlow 2.0 —— ResNet 网络设计代码

原文:https://www.cnblogs.com/cxhzy/p/13758763.html

(0)
(0)
   
举报
评论 一句话评论(0)
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!