
Keras class_weight and sample_weight usage

Date: 2019-12-19 21:51:21

Reposted from: https://stackoverflow.com/questions/57610804/when-is-the-timing-to-use-sample-weights-in-keras

import tensorflow as tf
import numpy as np

data_size = 100
input_size = 3
classes = 3

x_train = np.random.rand(data_size, input_size)
y_train = np.random.randint(0, classes, data_size)
# sample_weight_train = np.random.rand(data_size)
x_val = np.random.rand(data_size, input_size)
y_val = np.random.randint(0, classes, data_size)
# sample_weight_val = np.random.rand(data_size)

# shape must be a tuple: (input_size,) rather than (input_size), which is just an int
inputs = tf.keras.layers.Input(shape=(input_size,))
pred = tf.keras.layers.Dense(classes, activation='softmax')(inputs)

model = tf.keras.models.Model(inputs=inputs, outputs=pred)

loss = tf.keras.losses.sparse_categorical_crossentropy
metrics = tf.keras.metrics.sparse_categorical_accuracy

model.compile(loss=loss , metrics=[metrics], optimizer='adam')

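# Make the model static so we can compare it across the scenarios below.
# Note: this runs after compile(); Keras applies `trainable` changes only on the
# next compile(), so the weights may still update (val_loss drifts slightly below).
for layer in model.layers:
    layer.trainable = False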

# Baseline: all class weights equal to 1 (same result as passing no class_weight)
# model.fit(x=x_train, y=y_train, validation_data=(x_val, y_val))
class_weights = {0: 1., 1: 1., 2: 1.}
model.fit(x=x_train, y=y_train, class_weight=class_weights, validation_data=(x_val, y_val))
# which outputs:
# > loss: 1.1882 - sparse_categorical_accuracy: 0.3300 - val_loss: 1.1965 - val_sparse_categorical_accuracy: 0.3100

# Set the class weights to zero to see which losses and metrics are affected
class_weights = {0: 0, 1: 0, 2: 0}
model.fit(x=x_train, y=y_train, class_weight=class_weights, validation_data=(x_val, y_val))
# which outputs:
# > loss: 0.0000e+00 - sparse_categorical_accuracy: 0.3300 - val_loss: 1.1945 - val_sparse_categorical_accuracy: 0.3100

# Set the sample weights to zero to see which losses and metrics are affected
sample_weight_train = np.zeros(data_size)
sample_weight_val = np.zeros(data_size)
model.fit(x=x_train, y=y_train, sample_weight=sample_weight_train, validation_data=(x_val, y_val, sample_weight_val))
# which outputs:
# > loss: 0.0000e+00 - sparse_categorical_accuracy: 0.3300 - val_loss: 1.1931 - val_sparse_categorical_accuracy: 0.3100

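class_weight: a weight per output class (applied to every sample with that label)
sample_weight: a weight per individual data sample
In the runs above, zeroing either one drives the training loss to 0, while the accuracy metric stays unchanged.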
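
To make the relationship between the two concrete, here is a minimal pure-Python sketch (not Keras code). It assumes the weighted per-sample losses are simply averaged over the batch, which may differ from the exact reduction used by a given Keras version. The toy labels, predictions, and the helper `weighted_mean_loss` are hypothetical, introduced only for illustration.

```python
import math

def sparse_categorical_crossentropy(label, probs):
    """Per-sample cross-entropy for an integer label and a probability vector."""
    return -math.log(probs[label])

def weighted_mean_loss(labels, predictions, weights):
    """Scale each per-sample loss by its weight, then average over the batch.

    Assumption: plain division by the batch size, not by the sum of weights.
    """
    losses = [sparse_categorical_crossentropy(y, p)
              for y, p in zip(labels, predictions)]
    return sum(w * l for w, l in zip(weights, losses)) / len(losses)

# Hypothetical toy data: 4 samples, 3 classes
labels = [0, 1, 2, 1]
predictions = [
    [0.7, 0.2, 0.1],
    [0.3, 0.4, 0.3],
    [0.2, 0.2, 0.6],
    [0.5, 0.3, 0.2],
]

# A class_weight dict can be expressed as per-sample weights
# by looking up the weight of each sample's label.
class_weight = {0: 1.0, 1: 2.0, 2: 0.5}
sample_weight = [class_weight[y] for y in labels]

unweighted = weighted_mean_loss(labels, predictions, [1.0] * len(labels))
weighted = weighted_mean_loss(labels, predictions, sample_weight)
zeroed = weighted_mean_loss(labels, predictions, [0.0] * len(labels))

print(sample_weight)
print(unweighted, weighted, zeroed)
```

With all weights set to zero the loss collapses to zero, matching the `loss: 0.0000e+00` lines in the outputs above. An accuracy computed from the predictions alone never sees the weights, which is consistent with the unchanged accuracy values.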

Original post: https://www.cnblogs.com/yaos/p/12069527.html
