Cannot load BERT from local disk

I'm trying to load a locally downloaded M-BERT model with the Hugging Face Transformers API, but it throws an exception. I cloned this repo: https://huggingface.co/bert-base-multilingual-cased

bert = TFBertModel.from_pretrained("input/bert-base-multilingual-cased")

The directory structure is:

But I get this error:

Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/transformers/modeling_tf_utils.py", line 1277, in from_pretrained
    missing_keys, unexpected_keys = load_tf_weights(model, resolved_archive_file, load_weight_prefix)
  File "/usr/local/lib/python3.7/dist-packages/transformers/modeling_tf_utils.py", line 467, in load_tf_weights
    with h5py.File(resolved_archive_file, "r") as f:
  File "/usr/local/lib/python3.7/dist-packages/h5py/_hl/files.py", line 408, in __init__
    swmr=swmr)
  File "/usr/local/lib/python3.7/dist-packages/h5py/_hl/files.py", line 173, in make_fid
    fid = h5f.open(name, flags, fapl=fapl)
  File "h5py/_objects.pyx", line 54, in h5py._objects.with_phil.wrapper
  File "h5py/_objects.pyx", line 55, in h5py._objects.with_phil.wrapper
  File "h5py/h5f.pyx", line 88, in h5py.h5f.open
OSError: Unable to open file (file signature not found)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "train.py", line 81, in <module>
    __main__()
  File "train.py", line 59, in __main__
    model = create_model(num_classes)
  File "/content/drive/My Drive/msc-project/code/model.py", line 26, in create_model
    bert = TFBertModel.from_pretrained("input/bert-base-multilingual-cased")
  File "/usr/local/lib/python3.7/dist-packages/transformers/modeling_tf_utils.py", line 1280, in from_pretrained
    "Unable to load weights from h5 file. "
OSError: Unable to load weights from h5 file. If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. 

What am I doing wrong? Any help is appreciated. Thanks in advance.

As already pointed out in the comments, the argument to from_pretrained should be either the ID of a model hosted on huggingface.co or a local path:

A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.

(from the from_pretrained() documentation)
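To check whether a directory matches the description quoted above, a minimal sketch like the following lists which expected files are missing. The file names config.json and tf_model.h5 are an assumption based on what a TF model saved with save_pretrained() (and the cloned bert-base-multilingual-cased repo) typically contains; other transformers versions may shard large checkpoints or use different formats. missing_pretrained_files is a hypothetical helper, not part of the transformers API.

```python
from pathlib import Path

# Assumed contents of a TF save_pretrained() directory: a config plus a single
# TF 2.0 weights file. Sharded or differently named checkpoints are not covered.
EXPECTED_FILES = ("config.json", "tf_model.h5")


def missing_pretrained_files(model_dir):
    """Return the expected file names absent from model_dir (hypothetical helper)."""
    model_dir = Path(model_dir)
    if not model_dir.is_dir():
        # If the directory itself is not there, every expected file is missing.
        return list(EXPECTED_FILES)
    return [name for name in EXPECTED_FILES if not (model_dir / name).is_file()]
```

For example, print(missing_pretrained_files("input/bert-base-multilingual-cased")) should print an empty list when run from a working directory in which that relative path is valid.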

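Looking at your stack trace, your code appears to run from:

/content/drive/My Drive/msc-project/code/ (model.py lives there, and train.py is listed without a directory, which suggests it was launched from that folder). Relative paths are resolved against the current working directory, so unless your model is located at /content/drive/My Drive/msc-project/code/input/bert-base-multilingual-cased/, it cannot be loaded.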
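The difference between resolving a path against the current working directory and against the script's own location can be sketched with pathlib. Both function names are hypothetical helpers for illustration; the second one makes the model path independent of where the training script is launched from.

```python
import os
from pathlib import Path


def resolve_against_cwd(relative_path):
    """Resolve the way a plain relative string is resolved: against the CWD."""
    return (Path(os.getcwd()) / relative_path).resolve()


def resolve_against_script(relative_path, script_file):
    """Resolve against the directory containing script_file instead of the CWD."""
    return (Path(script_file).resolve().parent / relative_path).resolve()
```

Inside model.py you could then pass str(resolve_against_script("input/bert-base-multilingual-cased", __file__)) to TFBertModel.from_pretrained, which loads from the same folder no matter which directory train.py is started from.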

I would also write the path the same way as the documentation example, i.e.:

bert = TFBertModel.from_pretrained("./input/bert-base-multilingual-cased/")