Keras:编译模型后更新“可训练”属性
Keras: Update `trainable` attribute after compiling model
我在 Keras 中有一个条件 GAN (CGAN) 模型:
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Embedding, BatchNormalization, Dropout, multiply
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Sequential, Model
from keras.optimizers import Adam
import matplotlib.pyplot as plt
import tensorflow as tf
import keras as K
import numpy as np
import sys, os
import warnings
# Silence every warning category for the rest of the run to keep the console output clean.
warnings.filterwarnings('ignore')
# Directory where plot_images() saves generated sample grids; exist_ok avoids
# the check-then-create race of a separate os.path.exists() test.
os.makedirs('images', exist_ok=True)
class GAN(object):
    """Conditional GAN (CGAN) built from dense layers.

    Holds three Keras models: the generator ``G``, the discriminator ``D``
    and a stacked model (``G`` followed by ``D``) used to train the generator.
    """

    def __init__(self, width=28, height=28, channels=1, latent_dim=100, lr=0.0002):
        """Build and compile the generator, discriminator and stacked model.

        Args:
            width: width of the input images (cast to int).
            height: height of the input images (cast to int).
            channels: number of colour channels in the images (cast to int).
            latent_dim: length of the noise vector fed to the generator.
            lr: first positional argument passed to ``Adam``.
        """
        self.WIDTH = int(width)  # width of input images
        self.HEIGHT = int(height)  # height of input images
        self.CHANNELS = int(channels)  # n color channels in images
        self.SHAPE = (self.WIDTH, self.HEIGHT, self.CHANNELS)
        self.LATENT_DIM = latent_dim  # length of vector used to model latent space (= noise)
        self.N_CLASSES = 10  # total number of possible classes in the data
        # One optimizer instance is shared by all three models.
        self.OPTIMIZER = Adam(lr, 0.5)

        # generator
        self.G = self.generator()
        self.G.compile(loss='binary_crossentropy', optimizer=self.OPTIMIZER)

        # discriminator
        self.D = self.discriminator()
        self.D.compile(loss='binary_crossentropy', optimizer=self.OPTIMIZER, metrics=['accuracy'])
        self.D.trainable = False  # prevent stacked D from training; https://github.com/eriklindernoren/Keras-GAN/issues/73

        # stacked generator + discriminator
        # NOTE: this rebinds the instance attribute `stacked_G_D` from the
        # builder method to the model that the method returns.
        self.stacked_G_D = self.stacked_G_D()
        self.stacked_G_D.compile(loss='binary_crossentropy', optimizer=self.OPTIMIZER)

    def generator(self):
        """Build the generator model.

        Inputs are a noise vector of length ``LATENT_DIM`` and one int32 class
        label; the output is reshaped to ``(WIDTH, HEIGHT, CHANNELS)`` after a
        tanh activation. Prints the model summary.

        Returns:
            The uncompiled Keras ``Model``.
        """
        noise = Input((self.LATENT_DIM,), name='generator_noise')  # allows g to create different outputs
        label = Input((1,), name='generator_conditional', dtype='int32')  # allows g to create samples from one class

        # embed label in size of latent dimension
        h = Embedding(self.N_CLASSES, self.LATENT_DIM, input_length=1)(label)
        label_embedding = Flatten()(h)

        # unified model: condition the noise on the label by elementwise product
        h = multiply([noise, label_embedding])
        for units in (256, 512, 1024):
            h = Dense(units)(h)
            h = LeakyReLU(alpha=0.2)(h)
            h = BatchNormalization(momentum=0.8)(h)
        h = Dense(np.prod(self.SHAPE), activation='tanh')(h)
        o = Reshape((self.WIDTH, self.HEIGHT, self.CHANNELS))(h)

        model = Model(inputs=[noise, label], outputs=[o])
        model.summary()
        return model

    def discriminator(self):
        """Build the discriminator model.

        Inputs are an image of shape ``SHAPE`` and one int32 class label; the
        output is a single sigmoid unit. Prints the model summary.

        Returns:
            The uncompiled Keras ``Model``.
        """
        image = Input((self.SHAPE))
        label = Input((1,), dtype='int32')

        # embed the label in the shape of an image (flattened)
        h = Embedding(self.N_CLASSES, np.prod(self.SHAPE), input_length=1)(label)
        label_embedding = Flatten()(h)

        # parse out the image
        img = Flatten()(image)

        # unified model: condition the flattened image on the label
        h = multiply([img, label_embedding])
        h = Dense(512)(h)
        h = LeakyReLU(alpha=0.2)(h)
        h = Dense(512)(h)
        h = LeakyReLU(alpha=0.2)(h)
        h = Dropout(0.4)(h)
        h = Dense(512)(h)
        h = LeakyReLU(alpha=0.2)(h)
        h = Dropout(0.4)(h)
        o = Dense(1, activation='sigmoid')(h)

        model = Model(inputs=[image, label], outputs=[o])
        model.summary()
        return model

    def stacked_G_D(self):
        """Build the stacked model that feeds the generator's output into the discriminator.

        The same label input is passed to both ``self.G`` and ``self.D``.
        Prints the model summary.

        Returns:
            The uncompiled Keras ``Model`` mapping (noise, label) to D's output.
        """
        noise = Input((self.LATENT_DIM,))  # noise input
        label = Input((1,))  # conditional input
        img = self.G([noise, label])
        valid = self.D([img, label])
        model = Model(inputs=[noise, label], outputs=[valid])
        model.summary()
        return model

    def train(self, X_train, Y_train, epochs=20000, batch=32, save_interval=100):
        """Alternate discriminator and generator updates on random mini-batches.

        Args:
            X_train: array of training images, indexed along its first axis.
            Y_train: array of class labels aligned with ``X_train``.
            epochs: number of training iterations (one batch each).
            batch: number of samples per batch.
            save_interval: every this many iterations, print the losses and
                save a grid of generated samples via ``plot_images``.
        """
        real_targets = np.ones((batch, 1))
        fake_targets = np.zeros((batch, 1))

        for i in range(epochs):
            # train the discriminator on one real and one generated batch
            idx = np.random.randint(0, X_train.shape[0], batch)
            imgs, labels = X_train[idx], Y_train[idx]
            noise = np.random.normal(0, 1, (batch, self.LATENT_DIM))
            fake_imgs = self.G.predict([noise, labels])
            d_loss_real = self.D.train_on_batch([imgs, labels], real_targets)
            d_loss_fake = self.D.train_on_batch([fake_imgs, labels], fake_targets)
            # D is compiled with an accuracy metric, so each result is [loss, accuracy]
            d_loss = np.add(d_loss_real, d_loss_fake) * 0.5

            # train the generator: ask the stacked model to label fakes as real
            sample_labels = np.random.randint(0, self.N_CLASSES, batch).reshape(batch, 1)
            g_loss = self.stacked_G_D.train_on_batch([noise, sample_labels], real_targets)

            if i % save_interval == 0:
                # gen loss uses placeholder {3}; {2} would repeat the accuracy
                print('epoch: {0} - disc loss: {1}, disc accuracy: {2}, gen loss: {3}'.format(i, d_loss[0], d_loss[1]*100, g_loss))
                filename = 'mnist_{0}-{1}-{2}.png'.format(i, d_loss[0], g_loss)
                self.plot_images(save_to_disk=True, filename=filename)

    def plot_images(self, save_to_disk=False, n_images=10, filename=None, rows=2, size_scalar=4, class_arr=None):
        """Generate ``n_images`` samples and draw them in a grid.

        Args:
            save_to_disk: if true, save the figure under ``images/`` and close
                all figures; otherwise call ``fig.show()``.
            n_images: number of samples to generate.
            filename: output file name; defaults to ``'mnist.png'``.
            rows: number of grid rows.
            size_scalar: figure size per grid cell.
            class_arr: class labels to condition on; defaults to
                ``0..n_images-1`` modulo ``N_CLASSES``.
        """
        if not filename:
            filename = 'mnist.png'
        noise = np.random.normal(0, 1, (n_images, self.LATENT_DIM))
        classes = class_arr if class_arr is not None else np.arange(0, n_images) % self.N_CLASSES
        images = self.G.predict([noise, classes])

        # add_subplot needs an integer column count; np.ceil returns a float
        cols = int(np.ceil(n_images / rows))  # n_cols in grid
        fig = plt.figure(figsize=(cols * size_scalar, rows * size_scalar))
        for i in range(n_images):
            ax = fig.add_subplot(rows, cols, i + 1)
            # drop the channel axis; assumes single-channel images
            image = np.reshape(images[i], (self.WIDTH, self.HEIGHT))
            ax.imshow(image)
        fig.subplots_adjust(hspace=0, wspace=0)

        if save_to_disk:
            fig.savefig(os.path.join('images', filename))
            plt.close('all')
        else:
            fig.show()
# Load the training split only; the second returned split is discarded.
(X_train, Y_train), _ = mnist.load_data()

# rescale {-1 to 1}, matching the generator's tanh output range
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
# add a trailing axis (axis 3) -- presumably the channel axis; assumes load_data() returns 3-D image arrays
X_train = np.expand_dims(X_train, axis=3)

# Build the models with default hyper-parameters and train.
gan = GAN()
gan.train(X_train, Y_train)
我的目标是定期冻结鉴别器,使其暂时无法学习(这是一项实验性工作)。但是,在模型编译之后,我找不到能够真正更新 gan.D 的 .trainable 属性的方法。我已经尝试过定期手动修改该属性,但无论如何,鉴别器仍然在继续学习。
在编译模型之后,究竟是否可以更新模型的 trainable 属性?如果可以,希望能给出一个说明如何实现的简单示例,不胜感激!
啊,编译模型之后是可以更新模型的 .trainable 属性的,你只需要在修改属性后重新编译模型:
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Embedding, BatchNormalization, Dropout, multiply
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Sequential, Model
from keras.optimizers import Adam
import matplotlib.pyplot as plt
import tensorflow as tf
import keras as K
import numpy as np
import sys, os
import warnings
# Silence every warning category for the rest of the run to keep the console output clean.
warnings.filterwarnings('ignore')
# Directory where plot_images() saves generated sample grids; exist_ok avoids
# the check-then-create race of a separate os.path.exists() test.
os.makedirs('images', exist_ok=True)
class GAN(object):
    """Conditional GAN (CGAN) whose discriminator can be frozen/unfrozen during training.

    Holds three Keras models: the generator ``G``, the discriminator ``D``
    and a stacked model (``G`` followed by ``D``) used to train the generator.
    ``train`` can periodically flip ``D.trainable`` and recompile ``D``.
    """

    def __init__(self, width=28, height=28, channels=1, latent_dim=100, lr=0.0002):
        """Build and compile the generator, discriminator and stacked model.

        Args:
            width: width of the input images (cast to int).
            height: height of the input images (cast to int).
            channels: number of colour channels in the images (cast to int).
            latent_dim: length of the noise vector fed to the generator.
            lr: first positional argument passed to ``Adam``.
        """
        self.WIDTH = int(width)  # width of input images
        self.HEIGHT = int(height)  # height of input images
        self.CHANNELS = int(channels)  # n color channels in images
        self.SHAPE = (self.WIDTH, self.HEIGHT, self.CHANNELS)
        self.LATENT_DIM = latent_dim  # length of vector used to model latent space (= noise)
        self.N_CLASSES = 10  # total number of possible classes in the data
        # One optimizer instance is shared by all three models.
        self.OPTIMIZER = Adam(lr, 0.5)

        # generator
        self.G = self.generator()
        self.G.compile(loss='binary_crossentropy', optimizer=self.OPTIMIZER)

        # discriminator
        self.D = self.discriminator()
        # NOTE(review): the flag is set *before* compile here, so D presumably
        # starts out frozen until train() toggles it -- confirm this is intended.
        self.D.trainable = False  # normally this line follows the initial compilation of the D
        self.D.compile(loss='binary_crossentropy', optimizer=self.OPTIMIZER, metrics=['accuracy'])

        # stacked generator + discriminator
        # NOTE: this rebinds the instance attribute `stacked_G_D` from the
        # builder method to the model that the method returns.
        self.stacked_G_D = self.stacked_G_D()
        self.stacked_G_D.compile(loss='binary_crossentropy', optimizer=self.OPTIMIZER)

    def generator(self):
        """Build the generator model.

        Inputs are a noise vector of length ``LATENT_DIM`` and one int32 class
        label; the output is reshaped to ``(WIDTH, HEIGHT, CHANNELS)`` after a
        tanh activation. Prints the model summary.

        Returns:
            The uncompiled Keras ``Model``.
        """
        noise = Input((self.LATENT_DIM,), name='generator_noise')  # allows g to create different outputs
        label = Input((1,), name='generator_conditional', dtype='int32')  # allows g to create samples from one class

        # embed label in size of latent dimension
        h = Embedding(self.N_CLASSES, self.LATENT_DIM, input_length=1)(label)
        label_embedding = Flatten()(h)

        # unified model: condition the noise on the label by elementwise product
        h = multiply([noise, label_embedding])
        for units in (256, 512, 1024):
            h = Dense(units)(h)
            h = LeakyReLU(alpha=0.2)(h)
            h = BatchNormalization(momentum=0.8)(h)
        h = Dense(np.prod(self.SHAPE), activation='tanh')(h)
        o = Reshape((self.WIDTH, self.HEIGHT, self.CHANNELS))(h)

        model = Model(inputs=[noise, label], outputs=[o])
        model.summary()
        return model

    def discriminator(self):
        """Build the discriminator model.

        Inputs are an image of shape ``SHAPE`` and one int32 class label; the
        output is a single sigmoid unit. Prints the model summary.

        Returns:
            The uncompiled Keras ``Model``.
        """
        image = Input((self.SHAPE))
        label = Input((1,), dtype='int32')

        # embed the label in the shape of an image (flattened)
        h = Embedding(self.N_CLASSES, np.prod(self.SHAPE), input_length=1)(label)
        label_embedding = Flatten()(h)

        # parse out the image
        img = Flatten()(image)

        # unified model: condition the flattened image on the label
        h = multiply([img, label_embedding])
        h = Dense(512)(h)
        h = LeakyReLU(alpha=0.2)(h)
        h = Dense(512)(h)
        h = LeakyReLU(alpha=0.2)(h)
        h = Dropout(0.4)(h)
        h = Dense(512)(h)
        h = LeakyReLU(alpha=0.2)(h)
        h = Dropout(0.4)(h)
        o = Dense(1, activation='sigmoid')(h)

        model = Model(inputs=[image, label], outputs=[o])
        model.summary()
        return model

    def stacked_G_D(self):
        """Build the stacked model that feeds the generator's output into the discriminator.

        The same label input is passed to both ``self.G`` and ``self.D``.
        Prints the model summary.

        Returns:
            The uncompiled Keras ``Model`` mapping (noise, label) to D's output.
        """
        noise = Input((self.LATENT_DIM,))  # noise input
        label = Input((1,))  # conditional input
        img = self.G([noise, label])
        valid = self.D([img, label])
        model = Model(inputs=[noise, label], outputs=[valid])
        model.summary()
        return model

    def train(self, X_train, Y_train, epochs=20000, batch=32, save_interval=100, toggle_D_trainable=None):
        """Alternate discriminator and generator updates on random mini-batches.

        Args:
            X_train: array of training images, indexed along its first axis.
            Y_train: array of class labels aligned with ``X_train``.
            epochs: number of training iterations (one batch each).
            batch: number of samples per batch.
            save_interval: every this many iterations, print the losses and
                save a grid of generated samples via ``plot_images``.
            toggle_D_trainable: if truthy, flip ``self.D.trainable`` and
                recompile ``self.D`` every this many iterations (iteration 0
                is skipped). ``None`` or ``0`` disables toggling.
        """
        real_targets = np.ones((batch, 1))
        fake_targets = np.zeros((batch, 1))

        for i in range(epochs):
            # train the discriminator on one real and one generated batch
            idx = np.random.randint(0, X_train.shape[0], batch)
            imgs, labels = X_train[idx], Y_train[idx]
            noise = np.random.normal(0, 1, (batch, self.LATENT_DIM))
            fake_imgs = self.G.predict([noise, labels])
            d_loss_real = self.D.train_on_batch([imgs, labels], real_targets)
            d_loss_fake = self.D.train_on_batch([fake_imgs, labels], fake_targets)
            # D is compiled with an accuracy metric, so each result is [loss, accuracy]
            d_loss = np.add(d_loss_real, d_loss_fake) * 0.5

            # train the generator: ask the stacked model to label fakes as real
            sample_labels = np.random.randint(0, self.N_CLASSES, batch).reshape(batch, 1)
            g_loss = self.stacked_G_D.train_on_batch([noise, sample_labels], real_targets)

            if i % save_interval == 0:
                # gen loss uses placeholder {3}; {2} would repeat the accuracy
                print('epoch: {0} - disc loss: {1}, disc accuracy: {2}, gen loss: {3}'.format(i, d_loss[0], d_loss[1]*100, g_loss))
                filename = 'mnist_{0}-{1}-{2}.png'.format(i, d_loss[0], g_loss)
                self.plot_images(save_to_disk=True, filename=filename)

            if i > 0 and toggle_D_trainable and i % toggle_D_trainable == 0:
                # a changed `trainable` flag only takes effect after recompiling
                self.D.trainable = not self.D.trainable
                self.D.compile(loss='binary_crossentropy', optimizer=self.OPTIMIZER, metrics=['accuracy'])

    def plot_images(self, save_to_disk=False, n_images=10, filename=None, rows=2, size_scalar=4, class_arr=None):
        """Generate ``n_images`` samples and draw them in a grid.

        Args:
            save_to_disk: if true, save the figure under ``images/`` and close
                all figures; otherwise call ``fig.show()``.
            n_images: number of samples to generate.
            filename: output file name; defaults to ``'mnist.png'``.
            rows: number of grid rows.
            size_scalar: figure size per grid cell.
            class_arr: class labels to condition on; defaults to
                ``0..n_images-1`` modulo ``N_CLASSES``.
        """
        if not filename:
            filename = 'mnist.png'
        noise = np.random.normal(0, 1, (n_images, self.LATENT_DIM))
        classes = class_arr if class_arr is not None else np.arange(0, n_images) % self.N_CLASSES
        images = self.G.predict([noise, classes])

        # add_subplot needs an integer column count; np.ceil returns a float
        cols = int(np.ceil(n_images / rows))  # n_cols in grid
        fig = plt.figure(figsize=(cols * size_scalar, rows * size_scalar))
        for i in range(n_images):
            ax = fig.add_subplot(rows, cols, i + 1)
            # drop the channel axis; assumes single-channel images
            image = np.reshape(images[i], (self.WIDTH, self.HEIGHT))
            ax.imshow(image)
        fig.subplots_adjust(hspace=0, wspace=0)

        if save_to_disk:
            fig.savefig(os.path.join('images', filename))
            plt.close('all')
        else:
            fig.show()
# Load the training split only; the second returned split is discarded.
(X_train, Y_train), _ = mnist.load_data()

# rescale {-1 to 1}, matching the generator's tanh output range
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
# add a trailing axis (axis 3) -- presumably the channel axis; assumes load_data() returns 3-D image arrays
X_train = np.expand_dims(X_train, axis=3)

# Build the models with default hyper-parameters, then train while flipping
# the discriminator's trainable flag every 1000 iterations.
gan = GAN()
gan.train(X_train, Y_train, save_interval=100, toggle_D_trainable=1000)
我在 Keras 中有一个条件 GAN (CGAN) 模型:
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Embedding, BatchNormalization, Dropout, multiply
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Sequential, Model
from keras.optimizers import Adam
import matplotlib.pyplot as plt
import tensorflow as tf
import keras as K
import numpy as np
import sys, os
import warnings
# Silence every warning category for the rest of the run to keep the console output clean.
warnings.filterwarnings('ignore')
# Directory where plot_images() saves generated sample grids; exist_ok avoids
# the check-then-create race of a separate os.path.exists() test.
os.makedirs('images', exist_ok=True)
class GAN(object):
    """Conditional GAN (CGAN) built from dense layers.

    Holds three Keras models: the generator ``G``, the discriminator ``D``
    and a stacked model (``G`` followed by ``D``) used to train the generator.
    """

    def __init__(self, width=28, height=28, channels=1, latent_dim=100, lr=0.0002):
        """Build and compile the generator, discriminator and stacked model.

        Args:
            width: width of the input images (cast to int).
            height: height of the input images (cast to int).
            channels: number of colour channels in the images (cast to int).
            latent_dim: length of the noise vector fed to the generator.
            lr: first positional argument passed to ``Adam``.
        """
        self.WIDTH = int(width)  # width of input images
        self.HEIGHT = int(height)  # height of input images
        self.CHANNELS = int(channels)  # n color channels in images
        self.SHAPE = (self.WIDTH, self.HEIGHT, self.CHANNELS)
        self.LATENT_DIM = latent_dim  # length of vector used to model latent space (= noise)
        self.N_CLASSES = 10  # total number of possible classes in the data
        # One optimizer instance is shared by all three models.
        self.OPTIMIZER = Adam(lr, 0.5)

        # generator
        self.G = self.generator()
        self.G.compile(loss='binary_crossentropy', optimizer=self.OPTIMIZER)

        # discriminator
        self.D = self.discriminator()
        self.D.compile(loss='binary_crossentropy', optimizer=self.OPTIMIZER, metrics=['accuracy'])
        self.D.trainable = False  # prevent stacked D from training; https://github.com/eriklindernoren/Keras-GAN/issues/73

        # stacked generator + discriminator
        # NOTE: this rebinds the instance attribute `stacked_G_D` from the
        # builder method to the model that the method returns.
        self.stacked_G_D = self.stacked_G_D()
        self.stacked_G_D.compile(loss='binary_crossentropy', optimizer=self.OPTIMIZER)

    def generator(self):
        """Build the generator model.

        Inputs are a noise vector of length ``LATENT_DIM`` and one int32 class
        label; the output is reshaped to ``(WIDTH, HEIGHT, CHANNELS)`` after a
        tanh activation. Prints the model summary.

        Returns:
            The uncompiled Keras ``Model``.
        """
        noise = Input((self.LATENT_DIM,), name='generator_noise')  # allows g to create different outputs
        label = Input((1,), name='generator_conditional', dtype='int32')  # allows g to create samples from one class

        # embed label in size of latent dimension
        h = Embedding(self.N_CLASSES, self.LATENT_DIM, input_length=1)(label)
        label_embedding = Flatten()(h)

        # unified model: condition the noise on the label by elementwise product
        h = multiply([noise, label_embedding])
        for units in (256, 512, 1024):
            h = Dense(units)(h)
            h = LeakyReLU(alpha=0.2)(h)
            h = BatchNormalization(momentum=0.8)(h)
        h = Dense(np.prod(self.SHAPE), activation='tanh')(h)
        o = Reshape((self.WIDTH, self.HEIGHT, self.CHANNELS))(h)

        model = Model(inputs=[noise, label], outputs=[o])
        model.summary()
        return model

    def discriminator(self):
        """Build the discriminator model.

        Inputs are an image of shape ``SHAPE`` and one int32 class label; the
        output is a single sigmoid unit. Prints the model summary.

        Returns:
            The uncompiled Keras ``Model``.
        """
        image = Input((self.SHAPE))
        label = Input((1,), dtype='int32')

        # embed the label in the shape of an image (flattened)
        h = Embedding(self.N_CLASSES, np.prod(self.SHAPE), input_length=1)(label)
        label_embedding = Flatten()(h)

        # parse out the image
        img = Flatten()(image)

        # unified model: condition the flattened image on the label
        h = multiply([img, label_embedding])
        h = Dense(512)(h)
        h = LeakyReLU(alpha=0.2)(h)
        h = Dense(512)(h)
        h = LeakyReLU(alpha=0.2)(h)
        h = Dropout(0.4)(h)
        h = Dense(512)(h)
        h = LeakyReLU(alpha=0.2)(h)
        h = Dropout(0.4)(h)
        o = Dense(1, activation='sigmoid')(h)

        model = Model(inputs=[image, label], outputs=[o])
        model.summary()
        return model

    def stacked_G_D(self):
        """Build the stacked model that feeds the generator's output into the discriminator.

        The same label input is passed to both ``self.G`` and ``self.D``.
        Prints the model summary.

        Returns:
            The uncompiled Keras ``Model`` mapping (noise, label) to D's output.
        """
        noise = Input((self.LATENT_DIM,))  # noise input
        label = Input((1,))  # conditional input
        img = self.G([noise, label])
        valid = self.D([img, label])
        model = Model(inputs=[noise, label], outputs=[valid])
        model.summary()
        return model

    def train(self, X_train, Y_train, epochs=20000, batch=32, save_interval=100):
        """Alternate discriminator and generator updates on random mini-batches.

        Args:
            X_train: array of training images, indexed along its first axis.
            Y_train: array of class labels aligned with ``X_train``.
            epochs: number of training iterations (one batch each).
            batch: number of samples per batch.
            save_interval: every this many iterations, print the losses and
                save a grid of generated samples via ``plot_images``.
        """
        real_targets = np.ones((batch, 1))
        fake_targets = np.zeros((batch, 1))

        for i in range(epochs):
            # train the discriminator on one real and one generated batch
            idx = np.random.randint(0, X_train.shape[0], batch)
            imgs, labels = X_train[idx], Y_train[idx]
            noise = np.random.normal(0, 1, (batch, self.LATENT_DIM))
            fake_imgs = self.G.predict([noise, labels])
            d_loss_real = self.D.train_on_batch([imgs, labels], real_targets)
            d_loss_fake = self.D.train_on_batch([fake_imgs, labels], fake_targets)
            # D is compiled with an accuracy metric, so each result is [loss, accuracy]
            d_loss = np.add(d_loss_real, d_loss_fake) * 0.5

            # train the generator: ask the stacked model to label fakes as real
            sample_labels = np.random.randint(0, self.N_CLASSES, batch).reshape(batch, 1)
            g_loss = self.stacked_G_D.train_on_batch([noise, sample_labels], real_targets)

            if i % save_interval == 0:
                # gen loss uses placeholder {3}; {2} would repeat the accuracy
                print('epoch: {0} - disc loss: {1}, disc accuracy: {2}, gen loss: {3}'.format(i, d_loss[0], d_loss[1]*100, g_loss))
                filename = 'mnist_{0}-{1}-{2}.png'.format(i, d_loss[0], g_loss)
                self.plot_images(save_to_disk=True, filename=filename)

    def plot_images(self, save_to_disk=False, n_images=10, filename=None, rows=2, size_scalar=4, class_arr=None):
        """Generate ``n_images`` samples and draw them in a grid.

        Args:
            save_to_disk: if true, save the figure under ``images/`` and close
                all figures; otherwise call ``fig.show()``.
            n_images: number of samples to generate.
            filename: output file name; defaults to ``'mnist.png'``.
            rows: number of grid rows.
            size_scalar: figure size per grid cell.
            class_arr: class labels to condition on; defaults to
                ``0..n_images-1`` modulo ``N_CLASSES``.
        """
        if not filename:
            filename = 'mnist.png'
        noise = np.random.normal(0, 1, (n_images, self.LATENT_DIM))
        classes = class_arr if class_arr is not None else np.arange(0, n_images) % self.N_CLASSES
        images = self.G.predict([noise, classes])

        # add_subplot needs an integer column count; np.ceil returns a float
        cols = int(np.ceil(n_images / rows))  # n_cols in grid
        fig = plt.figure(figsize=(cols * size_scalar, rows * size_scalar))
        for i in range(n_images):
            ax = fig.add_subplot(rows, cols, i + 1)
            # drop the channel axis; assumes single-channel images
            image = np.reshape(images[i], (self.WIDTH, self.HEIGHT))
            ax.imshow(image)
        fig.subplots_adjust(hspace=0, wspace=0)

        if save_to_disk:
            fig.savefig(os.path.join('images', filename))
            plt.close('all')
        else:
            fig.show()
# Load the training split only; the second returned split is discarded.
(X_train, Y_train), _ = mnist.load_data()

# rescale {-1 to 1}, matching the generator's tanh output range
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
# add a trailing axis (axis 3) -- presumably the channel axis; assumes load_data() returns 3-D image arrays
X_train = np.expand_dims(X_train, axis=3)

# Build the models with default hyper-parameters and train.
gan = GAN()
gan.train(X_train, Y_train)
我的目标是定期冻结鉴别器,使其暂时无法学习(这是一项实验性工作)。但是,在模型编译之后,我找不到能够真正更新 gan.D 的 .trainable 属性的方法。我已经尝试过定期手动修改该属性,但无论如何,鉴别器仍然在继续学习。
在编译模型之后,究竟是否可以更新模型的 trainable 属性?如果可以,希望能给出一个说明如何实现的简单示例,不胜感激!
啊,编译模型之后是可以更新模型的 .trainable 属性的,你只需要在修改属性后重新编译模型:
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Embedding, BatchNormalization, Dropout, multiply
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Sequential, Model
from keras.optimizers import Adam
import matplotlib.pyplot as plt
import tensorflow as tf
import keras as K
import numpy as np
import sys, os
import warnings
# Silence every warning category for the rest of the run to keep the console output clean.
warnings.filterwarnings('ignore')
# Directory where plot_images() saves generated sample grids; exist_ok avoids
# the check-then-create race of a separate os.path.exists() test.
os.makedirs('images', exist_ok=True)
class GAN(object):
    """Conditional GAN (CGAN) whose discriminator can be frozen/unfrozen during training.

    Holds three Keras models: the generator ``G``, the discriminator ``D``
    and a stacked model (``G`` followed by ``D``) used to train the generator.
    ``train`` can periodically flip ``D.trainable`` and recompile ``D``.
    """

    def __init__(self, width=28, height=28, channels=1, latent_dim=100, lr=0.0002):
        """Build and compile the generator, discriminator and stacked model.

        Args:
            width: width of the input images (cast to int).
            height: height of the input images (cast to int).
            channels: number of colour channels in the images (cast to int).
            latent_dim: length of the noise vector fed to the generator.
            lr: first positional argument passed to ``Adam``.
        """
        self.WIDTH = int(width)  # width of input images
        self.HEIGHT = int(height)  # height of input images
        self.CHANNELS = int(channels)  # n color channels in images
        self.SHAPE = (self.WIDTH, self.HEIGHT, self.CHANNELS)
        self.LATENT_DIM = latent_dim  # length of vector used to model latent space (= noise)
        self.N_CLASSES = 10  # total number of possible classes in the data
        # One optimizer instance is shared by all three models.
        self.OPTIMIZER = Adam(lr, 0.5)

        # generator
        self.G = self.generator()
        self.G.compile(loss='binary_crossentropy', optimizer=self.OPTIMIZER)

        # discriminator
        self.D = self.discriminator()
        # NOTE(review): the flag is set *before* compile here, so D presumably
        # starts out frozen until train() toggles it -- confirm this is intended.
        self.D.trainable = False  # normally this line follows the initial compilation of the D
        self.D.compile(loss='binary_crossentropy', optimizer=self.OPTIMIZER, metrics=['accuracy'])

        # stacked generator + discriminator
        # NOTE: this rebinds the instance attribute `stacked_G_D` from the
        # builder method to the model that the method returns.
        self.stacked_G_D = self.stacked_G_D()
        self.stacked_G_D.compile(loss='binary_crossentropy', optimizer=self.OPTIMIZER)

    def generator(self):
        """Build the generator model.

        Inputs are a noise vector of length ``LATENT_DIM`` and one int32 class
        label; the output is reshaped to ``(WIDTH, HEIGHT, CHANNELS)`` after a
        tanh activation. Prints the model summary.

        Returns:
            The uncompiled Keras ``Model``.
        """
        noise = Input((self.LATENT_DIM,), name='generator_noise')  # allows g to create different outputs
        label = Input((1,), name='generator_conditional', dtype='int32')  # allows g to create samples from one class

        # embed label in size of latent dimension
        h = Embedding(self.N_CLASSES, self.LATENT_DIM, input_length=1)(label)
        label_embedding = Flatten()(h)

        # unified model: condition the noise on the label by elementwise product
        h = multiply([noise, label_embedding])
        for units in (256, 512, 1024):
            h = Dense(units)(h)
            h = LeakyReLU(alpha=0.2)(h)
            h = BatchNormalization(momentum=0.8)(h)
        h = Dense(np.prod(self.SHAPE), activation='tanh')(h)
        o = Reshape((self.WIDTH, self.HEIGHT, self.CHANNELS))(h)

        model = Model(inputs=[noise, label], outputs=[o])
        model.summary()
        return model

    def discriminator(self):
        """Build the discriminator model.

        Inputs are an image of shape ``SHAPE`` and one int32 class label; the
        output is a single sigmoid unit. Prints the model summary.

        Returns:
            The uncompiled Keras ``Model``.
        """
        image = Input((self.SHAPE))
        label = Input((1,), dtype='int32')

        # embed the label in the shape of an image (flattened)
        h = Embedding(self.N_CLASSES, np.prod(self.SHAPE), input_length=1)(label)
        label_embedding = Flatten()(h)

        # parse out the image
        img = Flatten()(image)

        # unified model: condition the flattened image on the label
        h = multiply([img, label_embedding])
        h = Dense(512)(h)
        h = LeakyReLU(alpha=0.2)(h)
        h = Dense(512)(h)
        h = LeakyReLU(alpha=0.2)(h)
        h = Dropout(0.4)(h)
        h = Dense(512)(h)
        h = LeakyReLU(alpha=0.2)(h)
        h = Dropout(0.4)(h)
        o = Dense(1, activation='sigmoid')(h)

        model = Model(inputs=[image, label], outputs=[o])
        model.summary()
        return model

    def stacked_G_D(self):
        """Build the stacked model that feeds the generator's output into the discriminator.

        The same label input is passed to both ``self.G`` and ``self.D``.
        Prints the model summary.

        Returns:
            The uncompiled Keras ``Model`` mapping (noise, label) to D's output.
        """
        noise = Input((self.LATENT_DIM,))  # noise input
        label = Input((1,))  # conditional input
        img = self.G([noise, label])
        valid = self.D([img, label])
        model = Model(inputs=[noise, label], outputs=[valid])
        model.summary()
        return model

    def train(self, X_train, Y_train, epochs=20000, batch=32, save_interval=100, toggle_D_trainable=None):
        """Alternate discriminator and generator updates on random mini-batches.

        Args:
            X_train: array of training images, indexed along its first axis.
            Y_train: array of class labels aligned with ``X_train``.
            epochs: number of training iterations (one batch each).
            batch: number of samples per batch.
            save_interval: every this many iterations, print the losses and
                save a grid of generated samples via ``plot_images``.
            toggle_D_trainable: if truthy, flip ``self.D.trainable`` and
                recompile ``self.D`` every this many iterations (iteration 0
                is skipped). ``None`` or ``0`` disables toggling.
        """
        real_targets = np.ones((batch, 1))
        fake_targets = np.zeros((batch, 1))

        for i in range(epochs):
            # train the discriminator on one real and one generated batch
            idx = np.random.randint(0, X_train.shape[0], batch)
            imgs, labels = X_train[idx], Y_train[idx]
            noise = np.random.normal(0, 1, (batch, self.LATENT_DIM))
            fake_imgs = self.G.predict([noise, labels])
            d_loss_real = self.D.train_on_batch([imgs, labels], real_targets)
            d_loss_fake = self.D.train_on_batch([fake_imgs, labels], fake_targets)
            # D is compiled with an accuracy metric, so each result is [loss, accuracy]
            d_loss = np.add(d_loss_real, d_loss_fake) * 0.5

            # train the generator: ask the stacked model to label fakes as real
            sample_labels = np.random.randint(0, self.N_CLASSES, batch).reshape(batch, 1)
            g_loss = self.stacked_G_D.train_on_batch([noise, sample_labels], real_targets)

            if i % save_interval == 0:
                # gen loss uses placeholder {3}; {2} would repeat the accuracy
                print('epoch: {0} - disc loss: {1}, disc accuracy: {2}, gen loss: {3}'.format(i, d_loss[0], d_loss[1]*100, g_loss))
                filename = 'mnist_{0}-{1}-{2}.png'.format(i, d_loss[0], g_loss)
                self.plot_images(save_to_disk=True, filename=filename)

            if i > 0 and toggle_D_trainable and i % toggle_D_trainable == 0:
                # a changed `trainable` flag only takes effect after recompiling
                self.D.trainable = not self.D.trainable
                self.D.compile(loss='binary_crossentropy', optimizer=self.OPTIMIZER, metrics=['accuracy'])

    def plot_images(self, save_to_disk=False, n_images=10, filename=None, rows=2, size_scalar=4, class_arr=None):
        """Generate ``n_images`` samples and draw them in a grid.

        Args:
            save_to_disk: if true, save the figure under ``images/`` and close
                all figures; otherwise call ``fig.show()``.
            n_images: number of samples to generate.
            filename: output file name; defaults to ``'mnist.png'``.
            rows: number of grid rows.
            size_scalar: figure size per grid cell.
            class_arr: class labels to condition on; defaults to
                ``0..n_images-1`` modulo ``N_CLASSES``.
        """
        if not filename:
            filename = 'mnist.png'
        noise = np.random.normal(0, 1, (n_images, self.LATENT_DIM))
        classes = class_arr if class_arr is not None else np.arange(0, n_images) % self.N_CLASSES
        images = self.G.predict([noise, classes])

        # add_subplot needs an integer column count; np.ceil returns a float
        cols = int(np.ceil(n_images / rows))  # n_cols in grid
        fig = plt.figure(figsize=(cols * size_scalar, rows * size_scalar))
        for i in range(n_images):
            ax = fig.add_subplot(rows, cols, i + 1)
            # drop the channel axis; assumes single-channel images
            image = np.reshape(images[i], (self.WIDTH, self.HEIGHT))
            ax.imshow(image)
        fig.subplots_adjust(hspace=0, wspace=0)

        if save_to_disk:
            fig.savefig(os.path.join('images', filename))
            plt.close('all')
        else:
            fig.show()
# Load the training split only; the second returned split is discarded.
(X_train, Y_train), _ = mnist.load_data()

# rescale {-1 to 1}, matching the generator's tanh output range
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
# add a trailing axis (axis 3) -- presumably the channel axis; assumes load_data() returns 3-D image arrays
X_train = np.expand_dims(X_train, axis=3)

# Build the models with default hyper-parameters, then train while flipping
# the discriminator's trainable flag every 1000 iterations.
gan = GAN()
gan.train(X_train, Y_train, save_interval=100, toggle_D_trainable=1000)