基本 Tensorflow 示例 - 直线预测

Basic Tensorflow Example - Prediction of a Line

我正在尝试使用 Tensorflow 创建这个超级简单的示例,但我显然没有完全理解 Tensorflow 的 API。

我有以下代码。它最初不是我的 - 我从一些演示中找到它,但我不记得我在哪里找到它,否则我会给予作者荣誉。道歉。

保存训练好的线模型

import tensorflow as tf
import numpy as np

# Create 100 phony x, y data points in NumPy, y = x * 0.1 + 0.3
x_data = np.random.rand(100).astype(np.float32)
y_data = x_data * 0.1 + 0.3

# Try to find values for W and b that compute y_data = W * x_data + b
W = tf.Variable(tf.random_uniform([1], -1.0, 1.0), name='W')
b = tf.Variable(tf.zeros([1]), name='b')
y = W * x_data + b

# Minimize the mean squared errors.
loss = tf.reduce_mean(tf.square(y - y_data))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)

# Before starting, initialize the variables.  We will 'run' this first.
init = tf.global_variables_initializer()

# Create a session saver
saver = tf.train.Saver()

# Launch the graph.
sess = tf.Session() 

sess.run(init)

# Fit the line.
for step in range(201):
    sess.run(train)
    if step % 20 == 0:
        print(step, sess.run(W), sess.run(b))
        saver.save(sess, 'linemodel')

好的,没关系。我只想加载模型,然后查询我的模型以获得预测值。这是我尝试的代码:

加载和查询训练好的线模型

# This is going to load the line model
import tensorflow as tf

sess = tf.Session()
new_saver = tf.train.import_meta_graph('linemodel.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./')) # latest checkpoint
all_vars = tf.global_variables()
for v in all_vars:
    v_ = sess.run(v)
    print("This is {} with value: {}".format(v.name, v_))
    # this works


# None of the below works
# Tried this as well
#fetches = {
#   "input": tf.constant(10, name='input')
#}

#feed_dict = {"input": tf.constant(10, name='input')}
#vals = sess.run(fetches, feed_dict = feed_dict)
# Tried this and it didn't work
# query_value = tf.constant(10, name='query')

# print(sess.run(query_value))

这是一个非常基本的问题,但我怎样才能只传入一个值并像函数一样使用我的行。我是否需要更改线模型的构建方式?我的猜测是计算图没有设置在输出是我们可以获得的实际变量的地方。这个对吗?如果可以,我该如何修改这个程序?

您必须重新创建张量流图并将保存的权重加载到其中。我在您的代码中添加了几行,它提供了所需的输出。请检查一下。

import tensorflow as tf
import numpy as np

sess = tf.Session() 
new_saver = tf.train.import_meta_graph('linemodel.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./')) # latest checkpoint
all_vars = tf.global_variables()

# load saved weights into new variables
W = all_vars[0]
b = all_vars[1]

# build TF graph
x = tf.placeholder(tf.float32)
y = tf.add(tf.multiply(W,x),b)

# Session
init = tf.global_variables_initializer()
print(sess.run(all_vars))
sess.run(init)    
for i in range(2):
    x_ip = np.random.rand(10).astype(np.float32) # batch_size : 10
    vals = sess.run(y,feed_dict={x:x_ip})
    print vals

输出:

[array([ 0.1000001], dtype=float32), array([ 0.29999995], dtype=float32)]

[-0.21707924 -0.18646611 -0.00732027 -0.14248954 -0.54388255 -0.33952206  -0.34291503 -0.54771954 -0.60995424 -0.91694558]
[-0.45050886 -0.01207681 -0.38950539 -0.25888413 -0.0103816  -0.10003483 -0.04783082 -0.83299863 -0.53189355 -0.56571382]

希望对您有所帮助。