首页 > 其他 > 详细

rnn实现三位数加法的训练

时间:2017-01-21 22:07:41      阅读:329      评论:0      收藏:0      [点我收藏+]
#!/usr/bin/env python
# coding=utf-8

from keras.models import Sequential
from keras.layers import Activation, TimeDistributed, Dense, RepeatVector, recurrent
import numpy as np
import string
import random

class CharacterTable(object):
    """One-hot encode/decode strings over the alphabet '0123456789+ '.

    A string becomes a (maxlen, len(chars)) matrix with a single 1 in each row
    that corresponds to a character; decoding maps per-row indices back to
    characters.
    """

    def __init__(self, maxlen):
        # Digits, the '+' operator, and ' ' (used as right-padding).
        self.chars = string.digits + '+ '
        self.char_indices = dict((c, i) for i, c in enumerate(self.chars))
        self.indice_chars = dict((i, c) for i, c in enumerate(self.chars))
        self.maxlen = maxlen

    def encode(self, strs, maxlen=None):
        """Return a (maxlen, len(self.chars)) one-hot matrix for *strs*.

        Args:
            strs: string to encode; every character must be in self.chars.
            maxlen: number of rows; falls back to self.maxlen when falsy.

        Rows beyond len(strs) are left all-zero.

        Raises:
            ValueError: if *strs* is longer than *maxlen*.
            KeyError: if *strs* contains a character outside self.chars.
        """
        maxlen = maxlen if maxlen else self.maxlen
        if len(strs) > maxlen:
            raise ValueError(
                'string {!r} is longer than maxlen={}'.format(strs, maxlen))
        vec = np.zeros((maxlen, len(self.chars)))
        for i, c in enumerate(strs):
            vec[i, self.char_indices[c]] = 1
        return vec

    def decode(self, vec, calc_argmax=True):
        """Turn a one-hot/probability matrix (or index vector) back into text.

        Args:
            vec: 2-D matrix when calc_argmax is True (argmax is taken over the
                last axis); otherwise a 1-D sequence of character indices.
            calc_argmax: whether *vec* still needs the argmax step.
        """
        if calc_argmax:
            vec = vec.argmax(axis=-1)
        return ''.join(self.indice_chars[x] for x in vec)

def gen_num(max_digits=3):
    """Return a random non-negative int written with 1..max_digits digits.

    Each digit is drawn independently (with replacement), so values with
    repeated digits such as 11, 100 or 999 can occur. Drawing without
    replacement would make every such value impossible. A leading '0' is
    dropped by int(), so short values are somewhat over-represented.
    """
    n_digits = random.randint(1, max_digits)
    return int(''.join(random.choice(string.digits) for _ in range(n_digits)))

MAXLEN = 7  # longest query: 3 digits + '+' + 3 digits
ctable = CharacterTable(MAXLEN)

N_QUESTIONS = 50000

# Build N_QUESTIONS distinct addition problems. a+b and b+a share one key,
# so each unordered pair appears at most once.
questions, expected = [], []
seen = set()
while len(questions) < N_QUESTIONS:
    a, b = gen_num(), gen_num()
    key = tuple(sorted((a, b)))
    if key in seen:
        continue
    seen.add(key)
    # Right-pad with spaces so every sample has the same length.
    query = '{}+{}'.format(a, b).ljust(MAXLEN)
    # a + b <= 999 + 999 = 1998, so the answer needs at most 4 characters.
    ans = str(a + b).ljust(4)

    questions.append(query)
    expected.append(ans)
print('total questions', len(questions))

# One-hot tensors: X is (samples, MAXLEN, vocab), y is (samples, 4, vocab).
# Builtin `bool` replaces `np.bool`, which was deprecated in NumPy 1.20 and
# removed in 1.24.
X = np.zeros((len(questions), MAXLEN, len(ctable.chars)), dtype=bool)
y = np.zeros((len(questions), 4, len(ctable.chars)), dtype=bool)

for i, sent in enumerate(questions):
    X[i] = ctable.encode(sent)

for i, sent in enumerate(expected):
    y[i] = ctable.encode(sent, 4)

# Encoder-decoder: the first LSTM reads the whole query and returns its final
# state; RepeatVector feeds that vector to each of the 4 output steps; two
# LSTMs plus a per-step Dense+softmax emit one character per step.
model = Sequential()
model.add(recurrent.LSTM(128, input_shape=(MAXLEN, len(ctable.chars))))
model.add(RepeatVector(4))
model.add(recurrent.LSTM(128, return_sequences=True))
model.add(recurrent.LSTM(128, return_sequences=True))

model.add(TimeDistributed(Dense(len(ctable.chars))))
model.add(Activation('softmax'))

model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

# NOTE(review): `nb_epoch` is the Keras 1.x keyword (Keras 2 renamed it to
# `epochs` and only warns on the old name); kept so the script still runs on
# Keras 1 — confirm against the installed version.
model.fit(X, y, batch_size=64, nb_epoch=20, validation_split=0.02, verbose=2)

# Spot-check the trained model on 10 random questions.
# Keras' validation_split holds out the LAST fraction of the samples, so we
# draw from that tail to see answers the model was not trained on.
val_fraction = 0.02  # keep in sync with validation_split in model.fit
held_out_start = int(len(X) * (1 - val_fraction))
for _ in range(10):
    ind = np.random.randint(held_out_start, len(X))
    x_test, y_test = X[ind:ind + 1], y[ind:ind + 1]
    # Per-step argmax of the softmax output is what predict_classes returned;
    # calling predict directly also works where predict_classes was removed.
    y_preds = model.predict(x_test, verbose=0).argmax(axis=-1)
    print('Q', ctable.decode(x_test[0]))
    print('T', ctable.decode(y_test[0]))
    print('Pred', ctable.decode(y_preds[0], calc_argmax=False))


# Save the architecture (JSON text) and the weights (.h5) separately.
json_string = model.to_json()
# Text mode: to_json() returns str, which a binary-mode file rejects on Py3.
with open('rnn_add_model.json', 'w') as fw:
    fw.write(json_string)
model.save_weights('rnn_add_model.h5')

基本是模仿官网例子，精简了一点；训练约 1 小时，准确率 99.6%。

rnn实现三位数加法的训练

原文:http://www.cnblogs.com/jkmiao/p/6337862.html

(0)
(0)
   
举报
评论 一句话评论(0)
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!