Caffe SigmoidCrossEntropyLoss Layer 多标签分类c++
Caffe SigmoidCrossEntropyLoss Layer Multilabel classification c++
我有一个获取 2 个输入图像的网络,这两个图像属于 9 个 class 中的不止一个 class。我见过的所有示例 - 在 Caffe 文档中 - 直接从 prototxt 加载输入图像,但是我通过我的 C++ 代码输入信息。
我的输入层如下所示
input: "data"
input_shape{dim:20 dim:6 dim:100 dim:100}
input: "class_label"
input_shape{dim:20 dim:9}
损失层如下所示
layer {
name: "classes"
type: "InnerProduct"
bottom: "ip2"
top: "classes"
param { lr_mult: 1 }
param { lr_mult: 2 }
inner_product_param {
num_output: 9
weight_filler { type: "xavier" }
bias_filler { type: "constant" }
}
}
layer {
name: "class_loss"
type: "SigmoidCrossEntropyLoss"
bottom: "classes"
bottom: "class_label"
top: "class_loss"
}
我的假设是输入应该是一个看起来像这样的流
[0 0 1 0 1 0 1 0 0],其中 1 表示图像属于 class,0 表示不属于,这是真的吗?
我的第二个问题是,我应该从 SigmoidCrossEntropyLoss 层的输出中得到什么(例如 SoftmaxWithLoss 输出概率)?
你是对的:你案例中的标签应该是二进制 9 向量。
损失层的输出是标量损失值。当您训练您的网络时,您应该期望这个值会降低。对于预测(测试时间),您应该用简单的 sigmoid 层替换 sigmoid 损失层。某些 sigmoid 层的输出是一个 9 向量,每个条目代表相应 class.
存在的概率
deploy.prototxt
中的输出层应该类似于:
layer {
type: "Sigmoid"
name: "class_prob"
bottom: "classes"
top: "class_prob"
}
我有一个获取 2 个输入图像的网络,这两个图像属于 9 个 class 中的不止一个 class。我见过的所有示例 - 在 Caffe 文档中 - 直接从 prototxt 加载输入图像,但是我通过我的 C++ 代码输入信息。
我的输入层如下所示
input: "data"
input_shape{dim:20 dim:6 dim:100 dim:100}
input: "class_label"
input_shape{dim:20 dim:9}
损失层如下所示
layer {
name: "classes"
type: "InnerProduct"
bottom: "ip2"
top: "classes"
param { lr_mult: 1 }
param { lr_mult: 2 }
inner_product_param {
num_output: 9
weight_filler { type: "xavier" }
bias_filler { type: "constant" }
}
}
layer {
name: "class_loss"
type: "SigmoidCrossEntropyLoss"
bottom: "classes"
bottom: "class_label"
top: "class_loss"
}
我的假设是输入应该是一个看起来像这样的流 [0 0 1 0 1 0 1 0 0],其中 1 表示图像属于 class,0 表示不属于,这是真的吗?
我的第二个问题是,我应该从 SigmoidCrossEntropyLoss 层的输出中得到什么(例如 SoftmaxWithLoss 输出概率)?
你是对的:你案例中的标签应该是二进制 9 向量。
损失层的输出是标量损失值。当您训练您的网络时,您应该期望这个值会降低。对于预测(测试时间),您应该用简单的 sigmoid 层替换 sigmoid 损失层。某些 sigmoid 层的输出是一个 9 向量,每个条目代表相应 class.
存在的概率deploy.prototxt
中的输出层应该类似于:layer { type: "Sigmoid" name: "class_prob" bottom: "classes" top: "class_prob" }