Custom loss function with GradientTape, TF 2.6
I am trying to use a custom loss function in my Keras Sequential model (TensorFlow 2.6.0). This custom loss would (ideally) compute the data loss plus the residual of a physics equation (say, the diffusion equation, Navier-Stokes, etc.). The residual is based on the derivatives of the model output with respect to its inputs, which I want to compute with GradientTape.
In this MWE I dropped the data-loss term and the other equation losses, and use only the derivative of the output with respect to the input. The dataset can be found here.
from numpy import loadtxt
from keras.models import Sequential
from keras.layers import Dense
import tensorflow as tf #tf.__version__ = '2.6.0'
# load the dataset
dataset = loadtxt('pima-indians-diabetes.csv', delimiter=',')
# split into input (X) and output (y) variables
X = dataset[:,0:8] #X.shape = (768, 8)
y = dataset[:,8]
X = tf.convert_to_tensor(X, dtype=tf.float32)
y = tf.convert_to_tensor(y, dtype=tf.float32)
def customLoss(y_true, y_pred):
    x_tensor = tf.convert_to_tensor(model.input, dtype=tf.float32)
    # x_tensor = tf.cast(x_tensor, tf.float32)
    with tf.GradientTape() as t:
        t.watch(x_tensor)
        output = model(x_tensor)
    DyDX = t.gradient(output, x_tensor)
    dy_t = DyDX[:, 5:6]
    R_pred = dy_t
    # loss_data = tf.reduce_mean(tf.square(y_true - y_pred), axis=-1)
    loss_PDE = tf.reduce_mean(tf.square(R_pred))
    return loss_PDE
model = Sequential()
model.add(Dense(12, input_dim=8, activation='relu'))
model.add(Dense(12, activation='relu'))
model.add(Dense(12, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
model.compile(loss=customLoss, optimizer='adam', metrics=['accuracy'])
model.fit(X, y, epochs=15)
On execution, I get this ValueError:
ValueError: Passed in object of type <class 'keras.engine.keras_tensor.KerasTensor'>, not tf.Tensor
When I change loss=customLoss to loss='mse', the model starts training, but using customLoss is the whole point. Any ideas?
The problem seems to be the use of model.input inside the loss function. If I understand your code correctly, you can use this loss instead:
def custom_loss_pass(model, x_tensor):
    def custom_loss(y_true, y_pred):
        with tf.GradientTape() as t:
            t.watch(x_tensor)
            output = model(x_tensor)
        DyDX = t.gradient(output, x_tensor)
        dy_t = DyDX[:, 5:6]
        R_pred = dy_t
        # loss_data = tf.reduce_mean(tf.square(y_true - y_pred), axis=-1)
        loss_PDE = tf.reduce_mean(tf.square(R_pred))
        return loss_PDE
    return custom_loss
and then:
model.compile(loss=custom_loss_pass(model, X), optimizer='adam', metrics=['accuracy'])
I am not sure it does exactly what you want, but at least it works!
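Since the question mentions that the full loss would ideally be the data loss plus the PDE residual, the same closure pattern can be extended to combine both terms. This is only a sketch: the weighting coefficient `lambda_pde` is a hypothetical parameter (not part of the original code), and the data term here is plain MSE.

```python
import tensorflow as tf

def custom_loss_pass(model, x_tensor, lambda_pde=1.0):
    """Build a loss combining MSE data loss with a PDE-residual term.

    lambda_pde is a hypothetical weight balancing the two terms.
    """
    def custom_loss(y_true, y_pred):
        with tf.GradientTape() as t:
            t.watch(x_tensor)
            output = model(x_tensor)
        # d(output)/d(input); for 8 input features this has shape (N, 8)
        dydx = t.gradient(output, x_tensor)
        # residual built from the derivative w.r.t. the 6th input feature
        r_pred = dydx[:, 5:6]
        loss_pde = tf.reduce_mean(tf.square(r_pred))
        loss_data = tf.reduce_mean(tf.square(y_true - y_pred))
        return loss_data + lambda_pde * loss_pde
    return custom_loss
```

Note that `y_pred` inside the closure is the prediction Keras computed on the training batch, while `output` is recomputed on the full `x_tensor` under the tape; keeping both lets the data term use the batch while the residual term uses the watched inputs.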