解读tensorflow的FLOPs profile结果

Interpreting the FLOPs profile result of tensorflow

我想剖析一个非常简单的神经网络模型的FLOPs,用于对MNIST数据集进行分类,batch size为128。按照官方教程,我得到了以下模型的结果, 但我无法理解输出的某些部分。

w1 = tf.Variable(tf.random_uniform([784, 15]), name='w1')
w2 = tf.Variable(tf.random_uniform([15, 10]), name='w2')
b1 = tf.Variable(tf.zeros([15, ]), name='b1')
b2 = tf.Variable(tf.zeros([10, ]), name='b2')

hidden_layer = tf.add(tf.matmul(images_iter, w1), b1)
logits = tf.add(tf.matmul(hidden_layer, w2), b2)

loss_op = tf.reduce_sum(\
    tf.nn.softmax_cross_entropy_with_logits(logits=logits, 
                                            labels=labels_iter))
opetimizer = tf.train.AdamOptimizer(learning_rate=0.01)
train_op = opetimizer.minimize(loss_op)

images_iterlabels_iter是tf.data的迭代器,类似于占位符。

tf.profiler.profile(
    tf.get_default_graph(),
    options=tf.profiler.ProfileOptionBuilder.float_operation())

我使用这段代码(相当于 tfprof 注释行工具中的 scope -min_float_ops 1 -select float_ops -account_displayed_op_only)来分析 FLOP 并得到以下结果。

Profile:
node name | # float_ops
_TFProfRoot (--/23.83k flops)
  random_uniform (11.76k/23.52k flops)
    random_uniform/mul (11.76k/11.76k flops)
    random_uniform/sub (1/1 flops)
  random_uniform_1 (150/301 flops)
    random_uniform_1/mul (150/150 flops)
    random_uniform_1/sub (1/1 flops)
  Adam/mul (1/1 flops)
  Adam/mul_1 (1/1 flops)
  softmax_cross_entropy_with_logits_sg/Sub (1/1 flops)
  softmax_cross_entropy_with_logits_sg/Sub_1 (1/1 flops)
  softmax_cross_entropy_with_logits_sg/Sub_2 (1/1 flops)

我的问题是

  1. 括号中的数字是什么意思?例如,random_uniform_1 (150/301 flops),150和301是什么?
  2. 为什么_TFProfRoot括号中的第一个数字是“--”?
  3. 为什么Adam/mul和softmax_cross_entropy_with_logits_sg/Sub的翻牌是1?

我知道一个问题读这么久很泄气,但是一个绝望的男孩,在官方文档中找不到相关信息,需要你们的帮助。

我试试看:

(1) 从这个例子来看,第一个数字是"self" flops,第二个数字表示命名范围下的"total" flops。例如:对于random_uniform(如果有这样的节点)、random_uniform/mul、random_uniform/sub这3个节点,分别取11.76k、11.76k、1个flops,在总共 23.52k 翻牌。

再举个例子:23.83k = 23.52k + 300.

这有意义吗?

(2) 根节点是分析器添加的 "virtual" 顶级节点,它没有 "self" flops ,或者换句话说,它的自身 flops 为零.

(3) 不确定为什么是 1。如果您可以使用 print(sess.graph_def)

打印 GraphDef 并找出该节点的真正含义,将会有所帮助

希望对您有所帮助。