tensorflow variable.assign_add 函数在这个例子中很神秘

tensorflow variable.assign_add function is mysterious in this example

我正在通过网上的可运行示例学习 TensorFlow,遇到了下面这个例子,很想弄清楚它是如何工作的。有人能解释一下 TensorFlow 的这个函数(`assign_add`)背后的数学原理,以及 `ns` 是如何从布尔类型的数据中得到它的值的吗?

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# Sample the complex plane on a regular grid: Y runs from -2.3 to 2.3 and
# X from -5 to 5, both in steps of 0.005. Z holds the complex points X + iY.
Y, X = np.mgrid[-2.3:2.3:0.005, -5:5:0.005]
Z = X+1j*Y

# c  : constant grid of complex points (complex64).
# zs : variable holding the current iterate for every point, initialised to c.
# ns : float32 variable of the same shape as c, initialised to zero; it
#      accumulates a per-point count (see `step` below).
c = tf.constant(Z, np.complex64)
zs = tf.Variable(c)
ns = tf.Variable(tf.zeros_like(c, tf.float32))

sess = tf.InteractiveSession()

tf.global_variables_initializer().run()

# One element-wise iteration of the map z -> z*z + c.
zs_ = zs*zs + c

# Boolean mask that is True wherever |zs_| > 4.
# NOTE(review): the name says "not diverged", but the comparison marks the
# points whose magnitude has already exceeded 4, i.e. the diverged ones.
not_diverged = tf.abs(zs_) > 4

# `step` bundles two updates so that one run performs both:
#   * zs is overwritten with the next iterate zs_;
#   * the boolean mask is cast to float32 (True -> 1.0, False -> 0.0) and
#     added in place to ns. This is how ns "gets its value from a bool":
#     each run adds 1 to every element where the mask is True.
step = tf.group(zs.assign(zs_),
 ns.assign_add(tf.cast(not_diverged, tf.float32)))

# Scalar summaries used only for the debug printout in the loop below:
# sums of ns, zs_ and c, and whether the mask is True for every point.
nx = tf.reduce_sum(ns)
zx = tf.reduce_sum(zs_)
cx = tf.reduce_sum(c)
zf = tf.reduce_all(not_diverged)

# Run 200 update steps, printing the summaries after each one.
for i in range(200): 
    step.run()
    print(sess.run([nx,zx,cx,zf]))

# Show the accumulated per-point counts as an image.
plt.imshow(ns.eval())
plt.show()
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# This defines the sampled region of the complex plane: a grid of points
# X + iY with Y from -2.3 to 2.3 and X from -5 to 5, in steps of 0.005.
Y, X = np.mgrid[-2.3:2.3:0.005, -5:5:0.005]
Z = X+1j*Y
c = tf.constant(Z, np.complex64)

# Tensors are immutable in TensorFlow, but variables aren't, so use a
# variable to be able to update the values later on.
# zs starts at c, which is P_c(0) for P_c(z) = z^2 + c.
zs = tf.Variable(c)

# ns keeps, for every grid point, a count of the steps at which the
# point was flagged as diverged.
ns = tf.Variable(tf.zeros_like(c, tf.float32))

sess = tf.InteractiveSession()
tf.global_variables_initializer().run()

# The Mandelbrot set M is defined by
#   c \in M \iff |P_c^n(0)| <= 2 for all n,  where P_c(z) = z^2 + c.
# The test below uses the looser threshold |z| > 4. Since |z| > 4
# implies |z| > 2, every point it flags is certainly outside M.
# The variable name `not_diverged` in the question is confusing, as the
# mask actually marks the opposite, so it is renamed to `diverged` here.
zs_ = zs*zs + c
diverged = tf.abs(zs_) > 4

# ns gets its value from a bool cast to a float, which is given by
# True \mapsto 1., False \mapsto 0.
# assign_add just says: add tf.cast(diverged, tf.float32) to the
# variable ns and store the result back in ns. tf.group bundles this
# with the update of zs so that a single run of `step` does both.
step = tf.group(
    zs.assign(zs_),
    ns.assign_add(tf.cast(diverged, tf.float32)))


# Here we iterate n as far as we like. Each step moves one term
# further along the sequence P^n_c(0), which must stay bounded
# in a disk of radius 2 for c to be in M.
for i in range(200):
    step.run()

# Anywhere with value > 0 in the plot is not in the Mandelbrot set.
# Anywhere with value = 0 MIGHT be in the Mandelbrot set;
# we don't know for sure, because we can only ever take n to be
# some finite number, whereas membership in the Mandelbrot set
# requires the sequence to stay bounded for all n!
plt.imshow(ns.eval())
plt.show()