首页 > Web开发 > 详细

Auto-Encoders实战

时间:2020-12-12 09:39:05      阅读:26      评论:0      收藏:0      [点我收藏+]

Outline

  • Auto-Encoder

  • Variational Auto-Encoders

Auto-Encoder

技术分享图片

创建编解码器

# Environment setup for the auto-encoder demo.
import os

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Sequential, layers
from PIL import Image
from matplotlib import pyplot as plt

# Fix random seeds so runs are reproducible.
tf.random.set_seed(22)
np.random.seed(22)
# Silence TensorFlow C++ info/warning logs (errors only).
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# BUG FIX: `tf.version` is a module; the version string is `tf.__version__`.
# (Curly quotes from the web scrape were also replaced with ASCII quotes.)
assert tf.__version__.startswith('2.')

def save_images(imgs, name):
    """Tile 100 grayscale 28x28 images into a 10x10 grid and save one PNG.

    Args:
        imgs: indexable collection of at least 100 uint8 arrays, each (28, 28).
        name: output file path for the composed 280x280 image.
    """
    new_im = Image.new('L', (280, 280))

    index = 0
    for i in range(0, 280, 28):
        for j in range(0, 280, 28):
            im = Image.fromarray(imgs[index], mode='L')
            new_im.paste(im, (i, j))
            index += 1

    # NOTE(review): the scrape lost this line's indentation; the save must run
    # inside the function, after both loops have filled the canvas.
    new_im.save(name)

# Hyperparameters.
h_dim = 20      # latent size: compress the 784-dim pixel vector to 20 dims
batchsz = 512
lr = 1e-3

# Fashion-MNIST; the autoencoder is unsupervised, so labels are never used.
(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()
# Scale pixels to [0, 1] floats for the sigmoid/BCE reconstruction loss.
x_train = x_train.astype(np.float32) / 255.
x_test = x_test.astype(np.float32) / 255.

# we do not need label
train_db = tf.data.Dataset.from_tensor_slices(x_train)
train_db = train_db.shuffle(batchsz * 5).batch(batchsz)
test_db = tf.data.Dataset.from_tensor_slices(x_test)
test_db = test_db.batch(batchsz)

print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)

class AE(keras.Model):
    """Vanilla auto-encoder: 784 pixels -> h_dim latent -> 784 logits."""

    def __init__(self):
        # BUG FIX: the scraped source had `def init` / `self.init()`;
        # Python constructors are `__init__`, chained via super().
        super(AE, self).__init__()

        # Encoder: [b, 784] ==> [b, h_dim]
        self.encoder = Sequential([
            layers.Dense(256, activation=tf.nn.relu),
            layers.Dense(128, activation=tf.nn.relu),
            layers.Dense(h_dim),
        ])

        # Decoder: [b, h_dim] ==> [b, 784] (raw logits; sigmoid applied by caller)
        self.decoder = Sequential([
            layers.Dense(128, activation=tf.nn.relu),
            layers.Dense(256, activation=tf.nn.relu),
            layers.Dense(784),
        ])

    def call(self, inputs, training=None):
        """Encode then decode; returns reconstruction logits of shape [b, 784]."""
        # [b, 784] ==> [b, h_dim]  (original comment said 19/10; h_dim is 20)
        h = self.encoder(inputs)
        # [b, h_dim] ==> [b, 784]
        x_hat = self.decoder(h)
        return x_hat

# Instantiate and build with a known input shape so summary() can count params.
model = AE()
model.build(input_shape=(None, 784)) # TensorFlow: prefer passing a tuple here
model.summary()

(60000, 28, 28) (60000,)
(10000, 28, 28) (10000,)
Model: "ae"
_
Layer (type) Output Shape Param #

sequential (Sequential) multiple 236436
_
sequential_1 (Sequential) multiple 237200

Total params: 473,636
Trainable params: 473,636
Non-trainable params: 0
_

训练

# Training: Adam on per-pixel binary cross-entropy over raw logits.
# (`lr=` is a deprecated alias; the supported keyword is `learning_rate`.)
optimizer = tf.optimizers.Adam(learning_rate=lr)

for epoch in range(10):

    for step, x in enumerate(train_db):

        # Flatten images: [b, 28, 28] ==> [b, 784].
        x = tf.reshape(x, [-1, 784])

        with tf.GradientTape() as tape:
            x_rec_logits = model(x)

            rec_loss = tf.losses.binary_crossentropy(x,
                                                     x_rec_logits,
                                                     from_logits=True)
            # BUG FIX: average the per-sample losses; `reduce_min` would
            # train only on the single easiest sample in each batch.
            rec_loss = tf.reduce_mean(rec_loss)

        grads = tape.gradient(rec_loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

        if step % 100 == 0:
            print(epoch, step, float(rec_loss))

    # Evaluation: reconstruct one test batch and save a 10x10 image grid,
    # once per epoch (the scrape had lost this dedent out of the step loop).
    x = next(iter(test_db))
    logits = model(tf.reshape(x, [-1, 784]))
    x_hat = tf.sigmoid(logits)
    # [b, 784] ==> [b, 28, 28]
    x_hat = tf.reshape(x_hat, [-1, 28, 28])

    # [b,28,28] ==> [2b,28,28]
    x_concat = tf.concat([x, x_hat], axis=0)
    # x_concat = x  # uncomment to save the original images instead
    x_concat = x_hat  # save reconstructions only; the concat above is unused
    x_concat = x_concat.numpy() * 255.
    x_concat = x_concat.astype(np.uint8)  # save as integer pixels
    if not os.path.exists('ae_images'):
        os.mkdir('ae_images')
    save_images(x_concat, 'ae_images/rec_epoch_%d.png' % epoch)

0 0 0.09717604517936707
0 100 0.12493347376585007
1 0 0.09747321903705597
1 100 0.12291513383388519
2 0 0.10048121958971024
2 100 0.12292417883872986
3 0 0.10093794018030167
3 100 0.12260882556438446
4 0 0.10006923228502274
4 100 0.12275046110153198
5 0 0.0993042066693306
5 100 0.12257824838161469
6 0 0.0967678651213646
6 100 0.12443818897008896
7 0 0.0965462476015091
7 100 0.12179268896579742
8 0 0.09197664260864258
8 100 0.12110235542058945
9 0 0.0913471132516861
9 100 0.12342415750026703

Auto-Encoders实战

原文:https://www.cnblogs.com/abdm-989/p/14123449.html

(0)
(0)
   
举报
评论 一句话评论(0)
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!