首页 > 其他 > 详细

吴裕雄--天生自然TensorFlow2教程:数据加载

时间:2020-01-02 22:36:28      阅读:98      评论:0      收藏:0      [点我收藏+]
import tensorflow as tf
from tensorflow import keras

# train: 60k | test: 10k
(x, y), (x_test, y_test) = keras.datasets.mnist.load_data()

x.shape
y.shape
# 0纯黑、255纯白
x.min(), x.max(), x.mean()
x_test.shape, y_test.shape
# 0-9有10种分类结果
y_onehot = tf.one_hot(y, depth=10)
y_onehot[:2]
# train: 50k | test: 10k
(x, y), (x_test, y_test) = keras.datasets.cifar10.load_data()
x.shape, y.shape, x_test.shape, y_test.shape
x.min(), x.max()
db = tf.data.Dataset.from_tensor_slices(x_test)
next(iter(db)).shape
db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
next(iter(db))[0].shape
打乱数据
db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
db = db.shuffle(10000)
数据预处理
def preprocess(x, y):
    x = tf.cast(x, dtype=tf.float32) / 255.
    y = tf.cast(y, dtype=tf.int32)
    y = tf.one_hot(y, depth=10)
    return x, y
db2 = db.map(preprocess)
res = next(iter(db2))
res[0].shape, res[1].shape
一次性得到多张照片
db3 = db2.batch(32)
res = next(iter(db3))
res[0].shape, res[1].shape
db_iter = iter(db3)
while True:
    next(db_iter)
repeat()
# 迭代不退出
db4 = db3.repeat()
# 迭代两次退出
db3 = db3.repeat(2)
def prepare_mnist_features_and_labels(x, y):
    x = tf.cast(x, tf.float32) / 255.
    y = tf.cast(y, tf.int64)
    return x, y

def mnist_dataset():
    (x, y), (x_val, y_val) = datasets.fashion_mnist.load_data()
    y = tf.one_hot(y, depth=10)
    y_val = tf.one_hot(y_val, depth=10)

    ds = tf.data.Dataset.from_tensor_slices((x, y))
    ds = ds.map(prepare_mnist_features_and_labels)
    ds = ds.shffle(60000).batch(100)
    ds_val = tf.data.Dataset.from_tensor_slices((x_val, y_val))
    ds_val = ds_val.map(prepare_mnist_features_and_labels)
    ds_val = ds_val.shuffle(10000).batch(100)
    return ds, ds_val

 

吴裕雄--天生自然TensorFlow2教程:数据加载

原文:https://www.cnblogs.com/tszr/p/12141969.html

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!