训练期间在 Keras 回调中查看 y_true 个批次
View y_true of batch in Keras Callback during training
我正在尝试在 Keras 中实现自定义损失函数。它要求我计算每个 y in B
的逆 class 频率之和
它是下面函数的1/epsilon(...)
部分
函数来自 this paper - 第 7 页
注意:我绝对可能误解了论文描述的内容。如果我是
请告诉我
我目前正在尝试使用 Keras 回调和 on_batch_start/end
方法来尝试确定输入批次的 class 频率(这意味着访问批次输入的 y_true
), 但运气不佳。
提前感谢您提供的任何帮助。
编辑: "little luck" 我的意思是我无法在训练期间找到访问单个批次的 y_true
的方法。示例:batch_size = 64
、train_features.shape == (50000, 120, 20)
,我无法在训练期间找到访问单个批次的 y_true
的方法。我可以从 on_batch_start/end
(self.model
) 访问 keras 模型,但我找不到访问批量大小 64 的实际 y_true
的方法。
from tensorflow.python.keras.callbacks import Callback
class FreqReWeight(Callback):
"""
Update learning rate by batch label frequency distribution -- for use with LDAM loss
"""
def __init__(self, C):
self.C = C
def on_train_begin(self, logs={}):
self.model.custom_val = 0
def on_batch_end(self, batch, logs=None):
print('batch index', batch)
print('Model being trained', self.model)
# how can one access the y_true of the batch?
LDAM损失函数
zj = "the j-th output of the model for the j-th class"
EDIT2
损失函数 - 用于在调用损失时进行测试
def LDAM(C):
def loss(y_true, y_pred):
print('shape', y_true.shape) # only prints each epoch, not each batch
return K.mean(y_pred) + C # NOT LDAM, just dummy for testing purposes
return loss
准备数据、编译模型和训练
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)
m = 64 # batch_size
model = keras.Sequential()
model.add(Conv2D(32, (3, 3), padding='same',
input_shape=x_train.shape[1:]))
model.add(Activation('relu'))
model.add(Conv2D(32, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(10))
model.add(Activation('softmax'))
model.compile(loss=LDAM(1), optimizer='sgd', metrics=['accuracy'])
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
model.fit(x_train, y_train,
batch_size=m,
validation_data=(x_test, y_test),
callbacks=[FreqReWeight(1)])
解决方案
最后问了一个更具体的问题。
都可以找到答案
我正在尝试在 Keras 中实现自定义损失函数。它要求我计算每个 y in B
它是下面函数的1/epsilon(...)
部分
函数来自 this paper - 第 7 页
注意:我绝对可能误解了论文描述的内容。如果我是
请告诉我我目前正在尝试使用 Keras 回调和 on_batch_start/end
方法来尝试确定输入批次的 class 频率(这意味着访问批次输入的 y_true
), 但运气不佳。
提前感谢您提供的任何帮助。
编辑: "little luck" 我的意思是我无法在训练期间找到访问单个批次的 y_true
的方法。示例:batch_size = 64
、train_features.shape == (50000, 120, 20)
,我无法在训练期间找到访问单个批次的 y_true
的方法。我可以从 on_batch_start/end
(self.model
) 访问 keras 模型,但我找不到访问批量大小 64 的实际 y_true
的方法。
from tensorflow.python.keras.callbacks import Callback
class FreqReWeight(Callback):
"""
Update learning rate by batch label frequency distribution -- for use with LDAM loss
"""
def __init__(self, C):
self.C = C
def on_train_begin(self, logs={}):
self.model.custom_val = 0
def on_batch_end(self, batch, logs=None):
print('batch index', batch)
print('Model being trained', self.model)
# how can one access the y_true of the batch?
LDAM损失函数
zj = "the j-th output of the model for the j-th class"
EDIT2
损失函数 - 用于在调用损失时进行测试
def LDAM(C):
def loss(y_true, y_pred):
print('shape', y_true.shape) # only prints each epoch, not each batch
return K.mean(y_pred) + C # NOT LDAM, just dummy for testing purposes
return loss
准备数据、编译模型和训练
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)
m = 64 # batch_size
model = keras.Sequential()
model.add(Conv2D(32, (3, 3), padding='same',
input_shape=x_train.shape[1:]))
model.add(Activation('relu'))
model.add(Conv2D(32, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(10))
model.add(Activation('softmax'))
model.compile(loss=LDAM(1), optimizer='sgd', metrics=['accuracy'])
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
model.fit(x_train, y_train,
batch_size=m,
validation_data=(x_test, y_test),
callbacks=[FreqReWeight(1)])
解决方案
最后问了一个更具体的问题。
都可以找到答案