Model is not being trained when tf.Session() is used

I am new to TensorFlow and Keras. I am trying to understand GANs using TF 1.x (following this repo: https://github.com/hse-aml), but I ran into a problem with the following function, which is used to create a session. My question is: what exactly does this function do, and why can't we simply use tf.Session() on its own? When I used tf.Session(), the model was not being trained.

import tensorflow as tf
from keras import backend as K

def weird_session():
    # Close any existing default session so we start from a clean state.
    curr_session = tf.get_default_session()
    if curr_session is not None:
        curr_session.close()
    # Reset the Keras-managed graph and session.
    K.clear_session()
    # Allocate GPU memory on demand rather than reserving it all up front.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    # InteractiveSession installs itself as the default session on construction.
    s = tf.InteractiveSession(config=config)
    # Register this session as the one Keras should use.
    K.set_session(s)
    return s
s=weird_session()

Here is the full code I used.

%tensorflow_version 1.x
import matplotlib.pyplot as plt
%matplotlib inline
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras import layers as L
import numpy as np


IMG_SHAPE=(36,36,3)
####################################################################################### DOWNLOAD DATASET
!git clone https://github.com/RaviSoji/colab_utils.git  # Include the "!".
import colab_utils
drive = colab_utils.get_gdrive()
colab_utils.pull_from_gdrive(drive, 'GAN/my.npy','hah.npy')
dataset=np.load('hah.npy')
plt.imshow(dataset[0])

data = np.float32(dataset)/255.
########################################################################################

from keras import backend as K

####################################################  weird_session(): the function in question
def weird_session():
    curr_session = tf.get_default_session()
    if curr_session is not None:
        curr_session.close()
    K.clear_session()
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    s = tf.InteractiveSession(config=config)
    K.set_session(s)
    return s
##################################################
s=weird_session()


IMG_SHAPE = data.shape[1:]

CODE_SIZE = 256

generator = Sequential()
generator.add(L.InputLayer([CODE_SIZE],name='noise'))
generator.add(L.Dense(10*8*8, activation='elu'))
generator.add(L.Reshape((8,8,10)))
generator.add(L.Conv2DTranspose(64,kernel_size=(5,5),activation='elu'))
generator.add(L.Conv2DTranspose(64,kernel_size=(5,5),activation='elu'))
generator.add(L.UpSampling2D(size=(2,2)))
generator.add(L.Conv2DTranspose(32,kernel_size=3,activation='elu'))
generator.add(L.Conv2DTranspose(32,kernel_size=3,activation='elu'))
generator.add(L.Conv2DTranspose(32,kernel_size=3,activation='elu'))
generator.add(L.Conv2D(3,kernel_size=3,activation=None))

discriminator = Sequential()
discriminator.add(L.InputLayer(IMG_SHAPE))
discriminator.add(L.Conv2D(filters=16, kernel_size=(3, 3), strides=1))
discriminator.add(L.LeakyReLU(0.12))
discriminator.add(L.Conv2D(filters=32, kernel_size=(3, 3), strides=1))
discriminator.add(L.LeakyReLU(0.12))
discriminator.add(L.MaxPool2D(pool_size=(2, 2)))
discriminator.add(L.Conv2D(filters=32, kernel_size=(3, 3), strides=1))
discriminator.add(L.LeakyReLU(0.12))
discriminator.add(L.Conv2D(filters=64, kernel_size=(3, 3), strides=1))
discriminator.add(L.LeakyReLU(0.12))
discriminator.add(L.MaxPool2D(pool_size=(2, 2))) 
discriminator.add(L.Flatten())
discriminator.add(L.Dense(256,activation='tanh'))
discriminator.add(L.Dense(2,activation=tf.nn.log_softmax))


noise = tf.placeholder('float32',[None,CODE_SIZE])
real_data = tf.placeholder('float32',[None,]+list(IMG_SHAPE))
logp_real = discriminator(real_data)
generated_data = generator(noise)
logp_gen = discriminator(generated_data)

d_loss = -tf.reduce_mean(logp_real[:,1] + logp_gen[:,0])
d_loss += tf.reduce_mean(discriminator.layers[-1].kernel**2)
disc_optimizer =  tf.train.GradientDescentOptimizer(1e-3).minimize(d_loss,var_list=discriminator.trainable_weights)

g_loss = -tf.reduce_mean(logp_gen[:,1])
gen_optimizer = tf.train.AdamOptimizer(1e-4).minimize(g_loss,var_list=generator.trainable_weights)

s.run(tf.global_variables_initializer())


def sample_noise_batch(bsize):
    return np.random.normal(size=(bsize, CODE_SIZE)).astype('float32')

def sample_data_batch(bsize):
    idxs = np.random.choice(np.arange(data.shape[0]), size=bsize)
    return data[idxs]

def sample_images(nrow,ncol, sharp=False):
    images = generator.predict(sample_noise_batch(bsize=nrow*ncol))
    if np.var(images)!=0:
        images = images.clip(np.min(data),np.max(data))
    for i in range(nrow*ncol):
        plt.subplot(nrow,ncol,i+1)
        if sharp:
            plt.imshow(images[i].reshape(IMG_SHAPE),cmap="gray", interpolation="none")
        else:
            plt.imshow(images[i].reshape(IMG_SHAPE),cmap="gray")
    plt.show()

def sample_probas(bsize):
    plt.title('Generated vs real data')
    plt.hist(np.exp(discriminator.predict(sample_data_batch(bsize)))[:,1],
             label='D(x)', alpha=0.5,range=[0,1])
    plt.hist(np.exp(discriminator.predict(generator.predict(sample_noise_batch(bsize))))[:,1],
             label='D(G(z))',alpha=0.5,range=[0,1])
    plt.legend(loc='best')
    plt.show()

from IPython import display

for epoch in range(50000):
    feed_dict = {
        real_data:sample_data_batch(100),
        noise:sample_noise_batch(100)
    }

    for i in range(5):
        s.run(disc_optimizer,feed_dict)
    s.run(gen_optimizer,feed_dict)

    if epoch %100==0:
        display.clear_output(wait=True)
        sample_images(2,3,True)

From the documentation here:

The only difference between Session and InteractiveSession is that an InteractiveSession installs itself as the default session, so you can call run() and eval() without explicitly referring to the session.

This is convenient if you are experimenting with TF in a Python shell or a Jupyter notebook, because it avoids having to pass an explicit Session object around in order to run operations.
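
As a minimal sketch of that difference (assuming TF 1.x; the constant c below is just a toy example):

import tensorflow as tf

c = tf.constant(42.0)

# InteractiveSession installs itself as the default session on construction,
# so eval()/run() work without naming the session explicitly.
sess = tf.InteractiveSession()
print(c.eval())          # uses the default (interactive) session
sess.close()

# A plain Session is only the default inside an as_default() block.
sess = tf.Session()
with sess.as_default():
    print(c.eval())      # works here; outside the block there is no default session
sess.close()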

So if you use a plain tf.Session, you also need to set it as the default session yourself.
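
For the code above, a hedged sketch of such a setup (plain_session is a hypothetical replacement for weird_session(), not something from the original repo) could look like this:

import tensorflow as tf
from keras import backend as K

def plain_session():
    # Same configuration as weird_session(), but with a regular tf.Session.
    K.clear_session()
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    s = tf.Session(config=config)
    K.set_session(s)  # let the Keras-built generator/discriminator use this session
    return s

s = plain_session()
with s.as_default():  # make it the default session, which InteractiveSession would do automatically
    s.run(tf.global_variables_initializer())
    # ... run disc_optimizer / gen_optimizer with feed_dict here, as in the training loop above ...

Either way, the explicit s.run(...) calls in the training loop stay the same; the key points are registering the session with Keras via K.set_session() and making it the default session.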