(Tensorflow) 在 model.fit() 期间卡在纪元 1

(Tensorflow) Stuck at Epoch 1 during model.fit()

我一直在尝试让 Tensorflow 2.8.0 与我的 Windows GPU (GeForce GTX 1650 Ti) 一起工作,即使它检测到我的 GPU,我制作的任何模型都会卡在 Epoch 1 当我尝试使用 fit 方法直到内核(我在 jupyter notebook 和 spyder 上试过)挂起并重新启动时无限期。

基于 Tensorflow 的 website,我已经下载了各自的 cuDNN 和 CUDA 版本,为此我通过 运行 各种命令进一步验证(连同 tensorflow 对我的 GPU 的检测) :

CUDA(应该是11.2)

(on command line)
nvcc --version
Build cuda_11.2.r11.2/compiler.29373293_0

(In python)
import tensorflow.python.platform.build_info as build
print(build.build_info['cuda_version'])
Output: '64_112'

cuDNN(应该是8.1)

import tensorflow.python.platform.build_info as build
print(build.build_info['cuda_version'])
Output: '64_8' # Looks like v8 but I've actually installed v8.1 (cuDNN v8.1.1 (Feburary 26th, 2021), for CUDA 11.0,11.1 and 11.2) so I think it's fine?

GPU 检查

tf.config.list_physical_devices('GPU')
Output: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

tf.test.is_gpu_available()
Output: True

tf.test.gpu_device_name()
Output: This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX AVX2
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
Created device /device:GPU:0 with 2153 MB memory:  -> device: 0, name: NVIDIA GeForce GTX 1650 Ti, pci bus id: 0000:01:00.0, compute capability: 7.5

然后当我尝试拟合任何类型的模型时,它就无法按照我上面的描述进行操作。 令人惊讶 的是,即使它无法加载 Tensorflow's CNN Tutorial, the only time it ever works is if I run the chunk of code from this 中描述的代码。这段代码看起来与其他所有失败的代码块几乎一样。

有人可以帮我解决这个问题吗?在过去的几个小时里,我一直在用我遇到的每一块代码拼命测试 TensorFlow,唯一没有卡在 Epoch 1 的时间是上面的 link。

**(我也通过 os.environ['CUDA_VISIBLE_DEVICES'] = '-1' 在我的 CPU 上尝试了 运行,一切似乎都正常)

更新(解决方案)

post 中的建议似乎有所帮助 - 我从压缩的 cudnn bin 子文件夹 (cudnn-11.2-windows-x64 复制了以下文件-v8.1.1.33\cuda\bin) 到我的 cuda bin 文件夹 (C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.2\bin)

cudnn_adv_infer64_8.dll
cudnn_adv_train64_8.dll
cudnn_cnn_infer64_8.dll
cudnn_cnn_train64_8.dll
cudnn_ops_infer64_8.dll
cudnn_ops_train64_8.dll

我最初似乎将 copy all cudnn*.dll files 误解为仅复制 cudnn64_8.dll 文件,而不是复制上面列出的所有其他文件。