GAN训练结果 D loss: nan, acc.: 50% G loss: nan
GAN training result D loss: nan, acc.: 50% G loss: nan
我正在尝试实现一个 GAN 来生成网络流量 .csv 数据集(表格 GAN),我的训练结果继续显示 [D loss: nan, acc.: 50%] [G loss: nan]。我想这是因为我的数据集在预处理后有 NaN 值,所以我使用了代码“assert not np.any(np.isnan(x))”,我得到了下面的错误。我需要帮助...
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
<ipython-input-30-5e46f39aeea1> in <module>
5 #Training the GAN model chosen: Vanilla GAN, CGAN, DCGAN, etc.
6 synthesizer = model_1(gan_args)
----> 7 synthesizer.train(dataset, train_args)
<ipython-input-26-65296d00d312> in train(self, data, train_arguments)
72 # Train Discriminator
73 # ---------------------
---> 74 batch_data = self.get_data_batch(data, self.batch_size)
75 noise = tf.random.normal((self.batch_size, self.noise_dim))
76
<ipython-input-26-65296d00d312> in get_data_batch(self, train, batch_size, seed)
56 train_ix = list(train_ix) + list(train_ix) # duplicate to cover ranges past the end
of the set
57 x = train.loc[train_ix[start_i: stop_i]].values
---> 58 assert not np.any(np.isnan(x))
59 return np.reshape(x, (batch_size, -1))
60
AssertionError: `
我终于明白了。在删除不需要的列后使用 .dropna(how='any', inplace = True) 它解决了问题。现在我的结果以 93.57% 的准确率生成。
我正在尝试实现一个 GAN 来生成网络流量 .csv 数据集(表格 GAN),我的训练结果继续显示 [D loss: nan, acc.: 50%] [G loss: nan]。我想这是因为我的数据集在预处理后有 NaN 值,所以我使用了代码“assert not np.any(np.isnan(x))”,我得到了下面的错误。我需要帮助...
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
<ipython-input-30-5e46f39aeea1> in <module>
5 #Training the GAN model chosen: Vanilla GAN, CGAN, DCGAN, etc.
6 synthesizer = model_1(gan_args)
----> 7 synthesizer.train(dataset, train_args)
<ipython-input-26-65296d00d312> in train(self, data, train_arguments)
72 # Train Discriminator
73 # ---------------------
---> 74 batch_data = self.get_data_batch(data, self.batch_size)
75 noise = tf.random.normal((self.batch_size, self.noise_dim))
76
<ipython-input-26-65296d00d312> in get_data_batch(self, train, batch_size, seed)
56 train_ix = list(train_ix) + list(train_ix) # duplicate to cover ranges past the end
of the set
57 x = train.loc[train_ix[start_i: stop_i]].values
---> 58 assert not np.any(np.isnan(x))
59 return np.reshape(x, (batch_size, -1))
60
AssertionError: `
我终于明白了。在删除不需要的列后使用 .dropna(how='any', inplace = True) 它解决了问题。现在我的结果以 93.57% 的准确率生成。