TensorFlow train function with multiple layers

I am new to TensorFlow and trying to understand it. I managed to create a model with one layer, but now I want to add two more. How can I get my train function to work? I want to train it with hundreds of X and Y values. I have created all the values I need: weights and biases for each layer, but I don't understand how to use them in my train function. And once it is trained, how do I use it, as I try to do in the last part of the code?

import tensorflow as tf
import numpy as np

print("TensorFlow version: {}".format(tf.__version__))
print("Eager execution: {}".format(tf.executing_eagerly()))

x = np.array([
    [10, 10, 30, 20],
], dtype=np.float32)

y = np.array([[10, 1, 1, 1]], dtype=np.float32)


class Model(object):
    def __init__(self, x, y):
        # randomly initialized weights and biases, one pair per layer
        self.W = tf.Variable(tf.random.normal((len(x), len(x[0]))))
        self.b = tf.Variable(tf.random.normal((len(y),)))
        self.W1 = tf.Variable(tf.random.normal((len(x), len(x[0]))))
        self.b1 = tf.Variable(tf.random.normal((len(y),)))
        self.W2 = tf.Variable(tf.random.normal((len(x), len(x[0]))))
        self.b2 = tf.Variable(tf.random.normal((len(y),)))

    def __call__(self, x):
        # three "layers": each is an element-wise multiply with that layer's
        # weights followed by its bias, feeding into the next
        out1 = tf.multiply(x, self.W) + self.b
        out2 = tf.multiply(out1, self.W1) + self.b1
        last_layer = tf.multiply(out2, self.W2) + self.b2
        # single-layer version: Input_Layer = self.W * x + self.b
        return last_layer


def loss(predicted_y, desired_y):
    return tf.reduce_sum(tf.square(predicted_y - desired_y))


optimizer = tf.optimizers.Adam(0.1)
    
def train(model, inputs, outputs):
    with tf.GradientTape() as t:
        current_loss = loss(model(inputs), outputs)
    grads = t.gradient(current_loss, [model.W, model.b])
    optimizer.apply_gradients(zip(grads, [model.W, model.b]))

    print(current_loss)


model = Model(x, y)

for i in range(10000):
    train(model, x, y)

for i in range(3):
    # read four values from the console; dtype=np.float32 converts the strings
    InputX = np.array([
        [input(), input(), input(), input()],
    ], dtype=np.float32)
    returning = tf.math.multiply(
        InputX, model.W, name=None
    )
    print("I think that output can be:", returning)

Just add the new variables to the list:

grads = t.gradient(current_loss, [model.W, model.b, model.W1, model.b1, model.W2, model.b2])
optimizer.apply_gradients(zip(grads, [model.W, model.b, model.W1, model.b1, model.W2, model.b2]))
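
For the second part of the question (how to use the model once it is trained), here is a minimal sketch of what the complete train function and a prediction step could look like. Collecting the six variables in a list is just the fix above written out in full, and predicting by calling model(...) instead of multiplying by model.W alone runs the input through all three layers via the model's own __call__. The sample InputX values are placeholders:

def train(model, inputs, outputs):
    # collect every trainable variable of the three layers
    variables = [model.W, model.b, model.W1, model.b1, model.W2, model.b2]
    with tf.GradientTape() as t:
        current_loss = loss(model(inputs), outputs)
    grads = t.gradient(current_loss, variables)
    optimizer.apply_gradients(zip(grads, variables))
    print(current_loss)

# after training, run a forward pass through all three layers
# by calling the model itself rather than multiplying by model.W only
InputX = np.array([[10, 10, 30, 20]], dtype=np.float32)
prediction = model(InputX)
print("I think that output can be:", prediction)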