CNN 的 TensorFlow Xavier 初始值设定项的 numpy 等价物是什么?

What is the numpy equivalent of TensorFlow Xavier initializer for CNN?

我想在 NumPy 中重新创建 Xavier 初始化(使用基本函数),就像 TensorFlow2 为 CNN 所做的一样。 以下是我学习在 NumPy 中进行 Xavier 初始化的方法:

# weights.shape = (2,2)
np.random.seed(0)
nodes_in = 2*2
weights = np.random.rand(2,2) * np.sqrt(1/nodes_in)

>>>array([[0.27440675, 0.35759468],
          [0.30138169, 0.27244159]])

这就是我为逻辑回归模型学习 Xavier 初始化的方式。似乎对于卷积神经网络应该有所不同,但我不知道如何。

initializer = tf.initializers.GlorotUniform(seed=0)
tf.Variable(initializer(shape=[2,2],dtype=tf.float32))

>>><tf.Variable 'Variable:0' shape=(2, 2) dtype=float32, numpy=
   array([[-0.7078647 ,  0.50461936],
          [ 0.73500216,  0.6633029 ]], dtype=float32)>

我对 TensorFlow documentation 解释 "fan_in" 和 "fan_out" 感到困惑。我猜这就是问题所在。有人可以帮我把它简化一下吗?

非常感谢!

[更新]:

当我按照 tf.keras.initializers.GlorotUniform 文档进行操作时,我仍然没有得到相同的结果:

# weights.shape = (2,2)
np.random.seed(0)
fan_in = 2*2
fan_out = 2*2
limit = np.sqrt(6/(fan_in + fan_out))
np.random.uniform(-limit,limit,size=(2,2))
>>>array([[0.08454747, 0.37271892],
          [0.17799139, 0.07773995]])

使用 Tensorflow

initializer = tf.initializers.GlorotUniform(seed=0)
tf.Variable(initializer(shape=[2,2],dtype=tf.float32))
<tf.Variable 'Variable:0' shape=(2, 2) dtype=float32, numpy=
array([[-0.7078647 ,  0.50461936],
       [ 0.73500216,  0.6633029 ]], dtype=float32)>

Numpy 中的相同逻辑

import math
np.random.seed(0)
scale = 1/max(1., (2+2)/2.)
limit = math.sqrt(3.0 * scale)
weights = np.random.uniform(-limit, limit, size=(2,2))
print(weights)
array([[0.11956818, 0.52710415],
       [0.25171784, 0.1099409 ]])

如果你观察,上面两个是不一样的,因为随机数生成器。 tensorflow 在内部使用无状态随机生成器,如下所示,如果您观察,我们会得到相同的输出。

tf.random.stateless_uniform(shape=(2,2),seed=[0, 0], minval=-limit, maxval=limit)
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[-0.7078647 ,  0.50461936],
       [ 0.73500216,  0.6633029 ]], dtype=float32)>

如果需要了解更多内部实现,可以查看https://github.com/tensorflow/tensorflow/blob/2b96f3662bd776e277f86997659e61046b56c315/tensorflow/python/ops/init_ops_v2.py#L525