具有附加参数的tensorflow自定义损失函数
tensorflow custom loss function with additional parameters
我了解自定义损失函数在 tensorflow 中的工作原理。假设在下面的代码中,a 和 b 是数字。
def customLoss( a,b):
def loss(y_true,y_pred):
loss=tf.math.reduce_mean(a*y_pred + b*y_pred)
return loss
return loss
但是如果 a 和 b 是与 y_pred 具有相同形状的数组呢?
比方说
y_pred= np.array([0,1,0,1])
a= np.arange(4)
b= np.ones(4)
我想损失函数的值等于6:
np.mean(a*y_pred + b*y_pred) #element-wise.
我觉得我的损失函数现在是错误的。它应该是每个样本的两个额外输入或权重。有人可以帮忙吗?谢谢。
假设 a
和 b
是所有损失计算中的固定数字,您可以执行类似于原始损失函数的操作:
import numpy as np
import tensorflow as tf
y_pred = np.array([0, 1, 0, 1])
y_true = np.array([0, 1, 0, 1])
a = np.arange(4)
b = np.ones(4)
def get_custom_loss(a, b):
a = tf.constant(a, dtype=tf.float32)
b = tf.constant(b, dtype=tf.float32)
def loss_fn(y_true, y_pred):
y_true = tf.cast(y_true, dtype=tf.float32)
y_pred = tf.cast(y_pred, dtype=tf.float32)
loss = tf.math.reduce_sum(
[tf.math.multiply(a, y_pred), tf.math.multiply(b, y_pred)])
return loss
return loss_fn
loss_fn = get_custom_loss(a, b)
print(loss_fn(y_pred, y_true))
tf.math.multiply(foo, bar)
将执行两个张量的逐元素乘法(参见 docs)。我假设您还想对两个产品的结果求和,而不是取平均值(即 1.5)。
请注意:您目前没有在损失中使用 y_true
。
我了解自定义损失函数在 tensorflow 中的工作原理。假设在下面的代码中,a 和 b 是数字。
def customLoss( a,b):
def loss(y_true,y_pred):
loss=tf.math.reduce_mean(a*y_pred + b*y_pred)
return loss
return loss
但是如果 a 和 b 是与 y_pred 具有相同形状的数组呢? 比方说
y_pred= np.array([0,1,0,1])
a= np.arange(4)
b= np.ones(4)
我想损失函数的值等于6:
np.mean(a*y_pred + b*y_pred) #element-wise.
我觉得我的损失函数现在是错误的。它应该是每个样本的两个额外输入或权重。有人可以帮忙吗?谢谢。
假设 a
和 b
是所有损失计算中的固定数字,您可以执行类似于原始损失函数的操作:
import numpy as np
import tensorflow as tf
y_pred = np.array([0, 1, 0, 1])
y_true = np.array([0, 1, 0, 1])
a = np.arange(4)
b = np.ones(4)
def get_custom_loss(a, b):
a = tf.constant(a, dtype=tf.float32)
b = tf.constant(b, dtype=tf.float32)
def loss_fn(y_true, y_pred):
y_true = tf.cast(y_true, dtype=tf.float32)
y_pred = tf.cast(y_pred, dtype=tf.float32)
loss = tf.math.reduce_sum(
[tf.math.multiply(a, y_pred), tf.math.multiply(b, y_pred)])
return loss
return loss_fn
loss_fn = get_custom_loss(a, b)
print(loss_fn(y_pred, y_true))
tf.math.multiply(foo, bar)
将执行两个张量的逐元素乘法(参见 docs)。我假设您还想对两个产品的结果求和,而不是取平均值(即 1.5)。
请注意:您目前没有在损失中使用 y_true
。