Tensorflow 中值

Tensorflow median value

如何计算tensorflow中列表的中值? 喜欢

node = tf.median(X)

X 是占位符
在numpy中,我可以直接使用np.median来获取中值。如何在tensorflow中使用numpy运算?

编辑: 此答案已过时,请改用 Lucas Venezian Povoa 的解决方案。更简单更快。

您可以使用以下方法计算 tensorflow 中的中位数:

def get_median(v):
    v = tf.reshape(v, [-1])
    mid = v.get_shape()[0]//2 + 1
    return tf.nn.top_k(v, mid).values[-1]

如果 X 已经是向量,您可以跳过整形。

如果您关心中值是偶数大小向量的两个中间元素的平均值,您应该改用它:

def get_real_median(v):
    v = tf.reshape(v, [-1])
    l = v.get_shape()[0]
    mid = l//2 + 1
    val = tf.nn.top_k(v, mid).values
    if l % 2 == 1:
        return val[-1]
    else:
        return 0.5 * (val[-1] + val[-2])

要使用 tensorflow 计算数组的中位数,您可以使用 percentile 函数,因为第 50 个百分位数是中位数。

import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np 

np.random.seed(0)   
x = np.random.normal(3.0, .1, 100)

median = tfp.stats.percentile(x, 50.0, interpolation='midpoint')

tf.Session().run(median)

上面的代码相当于np.percentile(x, 50, interpolation='midpoint').

我们可以修改 BlueSun 的解决方案,使其在 GPU 上更快:

def get_median(v):
    v = tf.reshape(v, [-1])
    m = v.get_shape()[0]//2
    return tf.reduce_min(tf.nn.top_k(v, m, sorted=False).values)

这与(根据我的经验)使用 tf.contrib.distributions.percentile(v, 50.0) 和 returns 实际元素之一一样快。