dl4j lstm 神经网络的输出是什么?
What is the output of the dl4j lstm neural network?
我正在研究一个文本生成例子https://github.com/eclipse/deeplearning4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/advanced/modelling/charmodelling/generatetext/GenerateTxtCharCompGraphModel.java。
lstm网络的输出是一个概率分布,按照我的理解,这是一个double数组,其中每个值表示的是数组中索引对应的字符出现的概率。所以我无法理解我们从分布中获取字符索引的以下代码:
/** Given a probability distribution over discrete classes, sample from the distribution
* and return the generated class index.
* @param distribution Probability distribution over classes. Must sum to 1.0
*/
static int sampleFromDistribution(double[] distribution, Random rng){
double d = 0.0;
double sum = 0.0;
for( int t=0; t<10; t++ ) {
d = rng.nextDouble();
sum = 0.0;
for( int i=0; i<distribution.length; i++ ){
sum += distribution[i];
if( d <= sum ) return i;
}
//If we haven't found the right index yet, maybe the sum is slightly
//lower than 1 due to rounding error, so try again.
}
//Should be extremely unlikely to happen if distribution is a valid probability distribution
throw new IllegalArgumentException("Distribution is invalid? d="+d+", sum="+sum);
}
我们似乎得到了一个随机值。为什么我们不直接选择价值最高的指数呢?如果我想 select 不是一个,而是两个或三个最有可能的下一个字符,我该怎么办?
这个函数从分布中采样,而不是简单地返回最可能的字符class。
这也意味着您得到的不是最有可能的字符,而是随机字符,其概率由给定的概率分布定义。
首先从均匀分布 (rng.nextDouble()
) 中获取一个介于 0 和 1 之间的随机值,然后找到该值在给定分布中的位置。
你可以想象它是这样的(如果你的字母表中只有 a 到 f):
[ a | b | c | d | e | f ]
0.0 0.3 0.5 1.0
如果抽取的随机值刚好超过 0.5,则会产生 e
,如果刚好小于 0.5,则会产生 d
。
每个字母根据其在分布中的权重在0和1之间的这条线上占据space的比例。
我正在研究一个文本生成例子https://github.com/eclipse/deeplearning4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/advanced/modelling/charmodelling/generatetext/GenerateTxtCharCompGraphModel.java。 lstm网络的输出是一个概率分布,按照我的理解,这是一个double数组,其中每个值表示的是数组中索引对应的字符出现的概率。所以我无法理解我们从分布中获取字符索引的以下代码:
/** Given a probability distribution over discrete classes, sample from the distribution
* and return the generated class index.
* @param distribution Probability distribution over classes. Must sum to 1.0
*/
static int sampleFromDistribution(double[] distribution, Random rng){
double d = 0.0;
double sum = 0.0;
for( int t=0; t<10; t++ ) {
d = rng.nextDouble();
sum = 0.0;
for( int i=0; i<distribution.length; i++ ){
sum += distribution[i];
if( d <= sum ) return i;
}
//If we haven't found the right index yet, maybe the sum is slightly
//lower than 1 due to rounding error, so try again.
}
//Should be extremely unlikely to happen if distribution is a valid probability distribution
throw new IllegalArgumentException("Distribution is invalid? d="+d+", sum="+sum);
}
我们似乎得到了一个随机值。为什么我们不直接选择价值最高的指数呢?如果我想 select 不是一个,而是两个或三个最有可能的下一个字符,我该怎么办?
这个函数从分布中采样,而不是简单地返回最可能的字符class。
这也意味着您得到的不是最有可能的字符,而是随机字符,其概率由给定的概率分布定义。
首先从均匀分布 (rng.nextDouble()
) 中获取一个介于 0 和 1 之间的随机值,然后找到该值在给定分布中的位置。
你可以想象它是这样的(如果你的字母表中只有 a 到 f):
[ a | b | c | d | e | f ]
0.0 0.3 0.5 1.0
如果抽取的随机值刚好超过 0.5,则会产生 e
,如果刚好小于 0.5,则会产生 d
。
每个字母根据其在分布中的权重在0和1之间的这条线上占据space的比例。