Errors when training an XLNET model

I am trying to train an XLNET model as follows. I want to set the hyperparameters myself rather than use any pretrained model.

from transformers import XLNetConfig, XLNetModel
from transformers import Trainer, TrainingArguments
# Initializing an XLNet configuration
configuration = XLNetConfig(use_mems_train = True)
model = XLNetModel(configuration)
train_dataset = 'sentences.txt'
training_args = TrainingArguments(
    output_dir='./results',          # output directory
    num_train_epochs=3,              # total # of training epochs
    per_device_train_batch_size=16,  # batch size per device during training
    warmup_steps=500,                # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir='./logs',            # directory for storing logs
)
trainer = Trainer(
    model=model,                         # the instantiated Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=train_dataset,         # training dataset
)
trainer.train()

But I get the following errors:

C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\framework\dtypes.py:526: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\framework\dtypes.py:527: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\framework\dtypes.py:528: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\framework\dtypes.py:529: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\framework\dtypes.py:530: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\framework\dtypes.py:535: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  0%|          | 0/9 [00:00<?, ?it/s]Traceback (most recent call last):
  File "untitled1/dfgd.py", line 23, in <module>
    trainer.train()
  File "C:\Users\DSP\AppData\Roaming\Python\Python37\site-packages\transformers\trainer.py", line 925, in train
    for step, inputs in enumerate(epoch_iterator):
  File "C:\Users\DSP\AppData\Roaming\Python\Python37\site-packages\torch\utils\data\dataloader.py", line 435, in __next__
    data = self._next_data()
  File "C:\Users\DSP\AppData\Roaming\Python\Python37\site-packages\torch\utils\data\dataloader.py", line 475, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "C:\Users\DSP\AppData\Roaming\Python\Python37\site-packages\torch\utils\data\_utils\fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "C:\Users\DSP\AppData\Roaming\Python\Python37\site-packages\transformers\data\data_collator.py", line 52, in default_data_collator
    features = [vars(f) for f in features]
  File "C:\Users\DSP\AppData\Roaming\Python\Python37\site-packages\transformers\data\data_collator.py", line 52, in <listcomp>
    features = [vars(f) for f in features]
TypeError: vars() argument must have __dict__ attribute
Exception ignored in: <function tqdm.__del__ at 0x0000014A17ABE828>
Traceback (most recent call last):
  File "C:\ProgramData\Anaconda3\lib\site-packages\tqdm\std.py", line 1039, in __del__
  File "C:\ProgramData\Anaconda3\lib\site-packages\tqdm\std.py", line 1223, in close
  File "C:\ProgramData\Anaconda3\lib\site-packages\tqdm\std.py", line 555, in _decr_instances
  File "C:\ProgramData\Anaconda3\lib\site-packages\tqdm\_monitor.py", line 51, in exit
  File "C:\ProgramData\Anaconda3\lib\threading.py", line 522, in set
  File "C:\ProgramData\Anaconda3\lib\threading.py", line 365, in notify_all
  File "C:\ProgramData\Anaconda3\lib\threading.py", line 348, in notify
TypeError: 'NoneType' object is not callable

How should I deal with these errors? How can I train my XLNET model?

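You should pass a dataset object, not the path of a text file. `Trainer` treats `train_dataset` as a map-style dataset, so the string `'sentences.txt'` is indexed character by character. `default_data_collator` then calls `vars()` on each one-character string, which has no `__dict__`, and that raises `TypeError: vars() argument must have __dict__ attribute`. A TFRecord dataset is only relevant for TensorFlow training; the PyTorch `Trainer` used here expects a `torch.utils.data.Dataset`-style object whose items are dicts of features such as `input_ids`. The rest of the output is noise: the `FutureWarning`s come from an older TensorFlow release imported alongside a newer NumPy, and the second traceback is tqdm failing to close its progress bar while the interpreter shuts down after the first error.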
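A minimal sketch of the fix, assuming the PyTorch `Trainer`: build a map-style dataset whose items are dicts with an `input_ids` list, and pass that instead of the file path. `LineByLineDataset`, `tokenize`, `pad_id` and `toy_tokenize` are hypothetical names chosen for this example, not part of `transformers`. Every line is padded to the same even `max_length` so batches can be stacked into one tensor; the length is kept even because, as far as I know, `DataCollatorForPermutationLanguageModeling` rejects odd sequence lengths.

```python
from typing import Callable, Dict, Iterable, List


class LineByLineDataset:
    """Hypothetical map-style dataset: one fixed-length example per non-empty line.

    torch's DataLoader accepts any object with __len__ and __getitem__ as a
    map-style dataset, so this sketch does not import torch; subclassing
    torch.utils.data.Dataset would work the same way.
    """

    def __init__(self, lines: Iterable[str], tokenize: Callable[[str], List[int]],
                 max_length: int = 128, pad_id: int = 0):
        if max_length % 2 != 0:
            raise ValueError("max_length should be even for permutation LM collation")
        self.examples: List[Dict[str, List[int]]] = []
        for line in lines:
            line = line.strip()
            if not line:
                continue  # skip blank lines
            ids = tokenize(line)[:max_length]               # truncate long lines
            ids = ids + [pad_id] * (max_length - len(ids))  # right-pad short lines
            self.examples.append({"input_ids": ids})

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, i: int) -> Dict[str, List[int]]:
        # A dict item is passed through by default_data_collator, unlike a str.
        return self.examples[i]


if __name__ == "__main__":
    # Toy whitespace "tokenizer" for illustration only (hypothetical); in practice
    # pass tokenizer.encode and pad_id=tokenizer.pad_token_id from an XLNet tokenizer.
    vocab: Dict[str, int] = {}

    def toy_tokenize(text: str) -> List[int]:
        return [vocab.setdefault(w, len(vocab) + 1) for w in text.split()]

    ds = LineByLineDataset(["the cat sat", "", "the dog"], toy_tokenize, max_length=4)
    print(len(ds), ds[0], ds[1])  # 2 {'input_ids': [1, 2, 3, 0]} {'input_ids': [1, 4, 0, 0]}
```

With `transformers`, one plausible wiring (check the names against your installed version) is: build the dataset from `open('sentences.txt', encoding='utf-8')` with `tokenizer.encode` and `pad_id=tokenizer.pad_token_id`; replace `XLNetModel` with `XLNetLMHeadModel`, because the bare `XLNetModel` returns hidden states rather than a loss and `Trainer` needs a loss to minimize; and pass `data_collator=DataCollatorForPermutationLanguageModeling(tokenizer=tokenizer)` so each batch gets the `perm_mask`, `target_mapping` and `labels` used for XLNet's permutation language modeling. The weights still start from a random initialization based on `XLNetConfig`; only the tokenizer vocabulary has to come from an existing SentencePiece model (for example `XLNetTokenizer.from_pretrained('xlnet-base-cased')`) or one you train yourself.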