计算Softmax时得到NaN和Infinity
Get the NaN and Infinity when calculating the Softmax
我正在尝试在 android 中实现 Softmax 函数。
这里是来自Whosebug的原始softmax函数供参考:
private double softmax(double input, double[] neuronValues) {
double total = Arrays.stream(neuronValues).map(Math::exp).sum();
return Math.exp(input) / total;
}
然后,我在 java Main:
中测试了 softmax
函数
public static void main(String[] args) {
try{
double input = Double.valueOf("123456789");
double[] doubleValues = new double[2];
doubleValues[0] = Double.valueOf("123456789");
doubleValues[1] = Double.valueOf("234567890");
double total = 0;
for (int i = 0; i < doubleValues.length; i++) {
double value = Math.exp(doubleValues[i]);
total += value;
}
double result = Math.exp(input) / total;
System.out.println(String.format("total: %s, result: %s", total, result));
}catch (Throwable ex) {
ex.printStackTrace();
}
}
输出为:
total: Infinity, result: NaN
好像返回的softmax函数是NaN
,不在[0,1].
范围内
据我了解,softmax函数应该将任何数字转换到[0,1]的范围内。
有什么问题?
您的数字太大,所以它的指数超出了 double 可以处理的范围(溢出)。 100 的指数具有 43 的数量级,因此 123456789 的指数将趋于无穷大。
total
是 double.POSITIVE_INFINITY。 result
是 inf / inf 所以它是 NaN。
尝试将您的输入标准化到一个范围,例如,min-max 标准化将输入转换为 [-1,1] 或 [0,-1] 的范围。这些范围通常用于机器学习,因为它们的幂级数是有界的。
我正在尝试在 android 中实现 Softmax 函数。
这里是来自Whosebug的原始softmax函数供参考:
private double softmax(double input, double[] neuronValues) {
double total = Arrays.stream(neuronValues).map(Math::exp).sum();
return Math.exp(input) / total;
}
然后,我在 java Main:
中测试了softmax
函数
public static void main(String[] args) {
try{
double input = Double.valueOf("123456789");
double[] doubleValues = new double[2];
doubleValues[0] = Double.valueOf("123456789");
doubleValues[1] = Double.valueOf("234567890");
double total = 0;
for (int i = 0; i < doubleValues.length; i++) {
double value = Math.exp(doubleValues[i]);
total += value;
}
double result = Math.exp(input) / total;
System.out.println(String.format("total: %s, result: %s", total, result));
}catch (Throwable ex) {
ex.printStackTrace();
}
}
输出为:
total: Infinity, result: NaN
好像返回的softmax函数是NaN
,不在[0,1].
据我了解,softmax函数应该将任何数字转换到[0,1]的范围内。
有什么问题?
您的数字太大,所以它的指数超出了 double 可以处理的范围(溢出)。 100 的指数具有 43 的数量级,因此 123456789 的指数将趋于无穷大。
total
是 double.POSITIVE_INFINITY。 result
是 inf / inf 所以它是 NaN。
尝试将您的输入标准化到一个范围,例如,min-max 标准化将输入转换为 [-1,1] 或 [0,-1] 的范围。这些范围通常用于机器学习,因为它们的幂级数是有界的。