在keras回调中使用带有自定义参数的自定义函数
Use custom function with custom parameters in keras callback
我正在 keras 中训练一个模型,我想在每个时期之后绘制结果图。我知道 keras 回调提供 "on_epoch_end" 函数,如果有人想在每个纪元之后进行一些计算,可以重载该函数,但我的函数需要一些额外的参数,这些参数在给定时会因元 class 错误而导致代码崩溃。详情如下:
这是我现在的做法,效果很好:-
class NewCallback(Callback):
def on_epoch_end(self, epoch, logs={}): #working fine, printing epoch after each epoch
print("EPOCH IS: "+str(epoch))
epochs=5
batch_size = 16
model_saved=False
if model_saved:
vae.load_weights(args.weights)
else:
# train the autoencoder
vae.fit(x_train,
epochs=epochs,
batch_size=batch_size,
validation_data=(x_test, None),
callbacks=[NewCallback()])
但是我想要这样的回调函数:-
class NewCallback(Callback,models,data,batch_size):
def on_epoch_end(self, epoch, logs={}):
print("EPOCH IS: "+str(epoch))
x=models.predict(data)
plt.plot(x)
plt.savefig(epoch+".png")
如果我这样称呼合适的话:
callbacks=[NewCallback(models, data, batch_size=batch_size)]
我收到这个错误:
TypeError: metaclass conflict: the metaclass of a derived class must be a (non-strict) subclass of the metaclasses of all its bases
我正在寻找一个更简单的解决方案来调用我的函数或解决此元 class 错误,我们将不胜感激!
我认为您想做的是定义一个 class ,它从回调派生并接受模型、数据等...作为构造函数参数。所以:
class NewCallback(Callback):
""" NewCallback descends from Callback
"""
def __init__(self, models, data, batch_size):
""" Save params in constructor
"""
self.models = models
def on_epoch_end(self, epoch, logs={}):
x = self.models.predict(self.data)
如果你想对测试数据进行预测,你可以试试这个
class CustomCallback(keras.callbacks.Callback):
def __init__(self, model, x_test, y_test):
self.model = model
self.x_test = x_test
self.y_test = y_test
def on_epoch_end(self, epoch, logs={}):
y_pred = self.model.predict(self.x_test, self.y_test)
print('y predicted: ', y_pred)
您需要在 model.fit
期间提及回调
model.sequence()
# your model architecture
model.fit(x_train, y_train, epochs=10,
callbacks=[CustomCallback(model, x_test, y_test)])
与on_epoch_end
类似,keras提供了很多其他方法
on_train_begin, on_train_end, on_epoch_begin, on_epoch_end, on_test_begin,
on_test_end, on_predict_begin, on_predict_end, on_train_batch_begin, on_train_batch_end,
on_test_batch_begin, on_test_batch_end, on_predict_batch_begin,on_predict_batch_end
我正在 keras 中训练一个模型,我想在每个时期之后绘制结果图。我知道 keras 回调提供 "on_epoch_end" 函数,如果有人想在每个纪元之后进行一些计算,可以重载该函数,但我的函数需要一些额外的参数,这些参数在给定时会因元 class 错误而导致代码崩溃。详情如下:
这是我现在的做法,效果很好:-
class NewCallback(Callback):
def on_epoch_end(self, epoch, logs={}): #working fine, printing epoch after each epoch
print("EPOCH IS: "+str(epoch))
epochs=5
batch_size = 16
model_saved=False
if model_saved:
vae.load_weights(args.weights)
else:
# train the autoencoder
vae.fit(x_train,
epochs=epochs,
batch_size=batch_size,
validation_data=(x_test, None),
callbacks=[NewCallback()])
但是我想要这样的回调函数:-
class NewCallback(Callback,models,data,batch_size):
def on_epoch_end(self, epoch, logs={}):
print("EPOCH IS: "+str(epoch))
x=models.predict(data)
plt.plot(x)
plt.savefig(epoch+".png")
如果我这样称呼合适的话:
callbacks=[NewCallback(models, data, batch_size=batch_size)]
我收到这个错误:
TypeError: metaclass conflict: the metaclass of a derived class must be a (non-strict) subclass of the metaclasses of all its bases
我正在寻找一个更简单的解决方案来调用我的函数或解决此元 class 错误,我们将不胜感激!
我认为您想做的是定义一个 class ,它从回调派生并接受模型、数据等...作为构造函数参数。所以:
class NewCallback(Callback):
""" NewCallback descends from Callback
"""
def __init__(self, models, data, batch_size):
""" Save params in constructor
"""
self.models = models
def on_epoch_end(self, epoch, logs={}):
x = self.models.predict(self.data)
如果你想对测试数据进行预测,你可以试试这个
class CustomCallback(keras.callbacks.Callback):
def __init__(self, model, x_test, y_test):
self.model = model
self.x_test = x_test
self.y_test = y_test
def on_epoch_end(self, epoch, logs={}):
y_pred = self.model.predict(self.x_test, self.y_test)
print('y predicted: ', y_pred)
您需要在 model.fit
期间提及回调model.sequence()
# your model architecture
model.fit(x_train, y_train, epochs=10,
callbacks=[CustomCallback(model, x_test, y_test)])
与on_epoch_end
类似,keras提供了很多其他方法
on_train_begin, on_train_end, on_epoch_begin, on_epoch_end, on_test_begin,
on_test_end, on_predict_begin, on_predict_end, on_train_batch_begin, on_train_batch_end,
on_test_batch_begin, on_test_batch_end, on_predict_batch_begin,on_predict_batch_end