为什么 torch.get_num_threads 尽管设置为 NUM_THREADS =12 但返回 1

Why is torch.get_num_threads returning 1 despite setting it to NUM_THREADS =12

我对使用 torch 还很陌生,我正在尝试使用 torch 运行 python 脚本。区块代码如下:

torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True
# uniform thread number
torch.set_num_threads(NUM_THREADS)
print('THREADS: ',torch.get_num_threads())
assert NUM_THREADS == torch.get_num_threads(), torch.get_num_threads()# Code fails here

这是我的 .env 文件:

CPU_NUM_THREADS=12
OMP_NUM_THREADS=12
OPENMP_NUM_THREADS=12
OPENBLAS_NUM_THREADS=12
MKL_NUM_THREADS=12
VECLIB_MAXIMUM_THREADS=12
NUMEXPR_NUM_THREADS=12

当我尝试打印 NUM_THREADS 时,它输出 12 但当我打印 torch.get_num_threads 时,它输出 1。

我的系统信息:

3.1 GHz 6 核英特尔酷睿 i5 MacOS。

我想知道为什么 torch.get_num_threads() 输出的是 1 而不是 12,我该如何解决?

更新:我将手电筒版本从 1.5.0 更新到 1.7.1 并解决了问题。