import tensorflow as tf


class BasicBlock(tf.keras.layers.Layer):
    """Residual block: two 3x3 conv + batch-norm stages plus a shortcut.

    The block computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``.

    Args:
        filter_num: Number of filters used by both 3x3 convolutions and by
            the 1x1 shortcut convolution when one is created.
        stride: Stride of the first 3x3 convolution. When it is not 1, the
            shortcut becomes a 1x1 convolution with the same stride so that
            both branches have the same shape before they are added.
    """

    def __init__(self, filter_num, stride=1):
        super(BasicBlock, self).__init__()

        self.conv1 = tf.keras.layers.Conv2D(
            filter_num, (3, 3), strides=stride, padding='same')
        self.bn1 = tf.keras.layers.BatchNormalization()
        self.relu = tf.keras.layers.Activation('relu')

        self.conv2 = tf.keras.layers.Conv2D(
            filter_num, (3, 3), strides=1, padding='same')
        self.bn2 = tf.keras.layers.BatchNormalization()

        if stride != 1:
            # The main branch shrinks the feature map, so the shortcut must
            # be projected to the same spatial size and channel count.
            self.downsample = tf.keras.Sequential()
            self.downsample.add(
                tf.keras.layers.Conv2D(filter_num, (1, 1), strides=stride))
        else:
            # Identity shortcut. NOTE: the addition in call() only works if
            # the input already has `filter_num` channels.
            self.downsample = lambda x: x

    def call(self, inputs, training=None):
        """Apply the residual block to `inputs`.

        Args:
            inputs: Input feature map.
            training: Forwarded to the batch-normalization layers so they
                use batch statistics in training and moving averages in
                inference.

        Returns:
            The ReLU-activated sum of the main branch and the shortcut.
        """
        out = self.conv1(inputs)
        out = self.bn1(out, training=training)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out, training=training)

        identity = self.downsample(inputs)

        # Element-wise residual addition of the two branches.
        output = tf.keras.layers.add([out, identity])
        return tf.nn.relu(output)


class ResNet(tf.keras.layers.Layer):
    """ResNet backbone built from four stages of `BasicBlock`s.

    Args:
        layer_dims: Sequence of four ints; ``layer_dims[i]`` is the number of
            `BasicBlock`s in stage ``i``. For example ``[2, 2, 2, 2]`` means
            four stages with two basic blocks each.
        num_classes: Number of units of the final dense layer, i.e. the
            number of classes to predict (logits, no softmax applied).
    """

    def __init__(self, layer_dims, num_classes=100):
        super(ResNet, self).__init__()

        # Stem: conv -> batch norm -> relu -> 2x2 max pooling.
        self.stem = tf.keras.Sequential([
            tf.keras.layers.Conv2D(64, (3, 3), strides=(1, 1)),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Activation('relu'),
            tf.keras.layers.MaxPool2D(
                pool_size=(2, 2), strides=[2, 2], padding='same'),
        ])

        # Every stage after the first halves the resolution (stride=2) and
        # doubles the number of filters.
        self.layer1 = self.build_resblock(64, layer_dims[0])
        self.layer2 = self.build_resblock(128, layer_dims[1], stride=2)
        self.layer3 = self.build_resblock(256, layer_dims[2], stride=2)
        self.layer4 = self.build_resblock(512, layer_dims[3], stride=2)

        # Average each of the 512 feature maps over its spatial dimensions:
        # [b, h, w, 512] -> [b, 512] (channels-last, the Keras default).
        self.avgpool = tf.keras.layers.GlobalAveragePooling2D()
        self.fc = tf.keras.layers.Dense(num_classes)

    def call(self, inputs, training=None):
        """Run the network and return class logits.

        Args:
            inputs: Batch of input images.
            training: Forwarded to the stem and to every stage so that their
                batch-normalization layers run in the right mode.

        Returns:
            Tensor of shape ``[b, num_classes]`` with unnormalized logits.
        """
        x = self.stem(inputs, training=training)

        x = self.layer1(x, training=training)
        x = self.layer2(x, training=training)
        x = self.layer3(x, training=training)
        x = self.layer4(x, training=training)

        # [b, c]
        x = self.avgpool(x)
        # [b, num_classes]
        x = self.fc(x)

        return x

    def build_resblock(self, filter_num, blocks, stride=1):
        """Build one stage made of `blocks` stacked `BasicBlock`s.

        Args:
            filter_num: Number of filters of every block in the stage.
            blocks: Total number of `BasicBlock`s in the stage.
            stride: Stride of the first block only; the remaining blocks
                always use stride 1.

        Returns:
            A `tf.keras.Sequential` holding the blocks in order.
        """
        res_blocks = tf.keras.Sequential()
        # Only the first block may downsample.
        res_blocks.add(BasicBlock(filter_num, stride))

        # The first block is already added, so add `blocks - 1` more.
        for _ in range(1, blocks):
            res_blocks.add(BasicBlock(filter_num, stride=1))

        return res_blocks


def resnet18():
    """Return a ResNet-18: four stages with two basic blocks each."""
    return ResNet([2, 2, 2, 2])


def resnet34():
    """Return a ResNet-34: stages with 3, 4, 6 and 3 basic blocks."""
    return ResNet([3, 4, 6, 3])
# Original article: https://www.cnblogs.com/cxhzy/p/13758763.html