如何在 Pytorch Lightning 微调之前测试模型?

How to test a model before fine-tuning in Pytorch Lightning?

我在 Google Colab 上运行以下代码。

import torch
from torch.utils.data import DataLoader
from transformers import BertJapaneseTokenizer, BertForSequenceClassification
import pytorch_lightning as pl

# Toy dataset: four samples, each pairing a 2-element 'data' tensor
# ([0,1], [2,3], [4,5], [6,7]) with a scalar 'labels' tensor (0..3).
dataset_for_loader = [
    {'data': torch.tensor([2 * i, 2 * i + 1]), 'labels': torch.tensor(i)}
    for i in range(4)
]
# Group the samples two at a time.
loader = DataLoader(dataset_for_loader, batch_size=2)

# Print every batch to show what the loader yields.
for idx, batch in enumerate(loader):
    print(f'# batch {idx}', batch, sep='\n')

# The nine article categories. Each name is also used as a directory name
# under ./text/, and a category's index in this list is its class label.
category_list = [
    "dokujo-tsushin",
    "it-life-hack",
    "kaden-channel",
    "livedoor-homme",
    "movie-enter",
    "peachy",
    "smax",
    "sports-watch",
    "topic-news",
]

# Tokenizer for the pretrained checkpoint named by MODEL_NAME
# (MODEL_NAME must already be defined before this point).
tokenizer = BertJapaneseTokenizer.from_pretrained(MODEL_NAME)

# Every article is padded/truncated to exactly this many tokens.
max_length = 128
dataset_for_loader = []
for label, category in enumerate(tqdm(category_list)):
    # ./text/<category>/ holds many plain-text articles for that category.
    for file in glob.glob(f'./text/{category}/{category}*'):
        # Use a context manager so each file handle is closed right away
        # instead of leaking one handle per article.
        # NOTE(review): assumes the corpus files are UTF-8 encoded — confirm.
        with open(file, encoding='utf-8') as f:
            lines = f.read().splitlines()
        # The article body starts on the fourth line; the first three lines
        # are skipped.
        text = '\n'.join(lines[3:])
        encoding = tokenizer(
            text,
            max_length=max_length,
            padding='max_length',
            truncation=True,
        )
        # Attach the class index so the model can compute a loss from it.
        encoding['labels'] = label
        # Convert every field to a tensor so DataLoader can batch the samples.
        encoding = {k: torch.tensor(v) for k, v in encoding.items()}
        dataset_for_loader.append(encoding)

# Fixed seed so the train/val/test split is reproducible across runs.
SEED = 0

# Shuffle with a dedicated seeded generator.
# The earlier form, random.shuffle(data, lambda: 0.0), did not shuffle: a
# constant 0.0 makes every swap target index 0, which merely rotates the list
# by one position. Because the samples are appended category by category, the
# splits then contained almost disjoint sets of labels. The second argument of
# random.shuffle was also removed in Python 3.11.
random.Random(SEED).shuffle(dataset_for_loader)

# 60% train / 20% validation / remainder (about 20%) test.
n = len(dataset_for_loader)
n_train = int(0.6*n)
n_val = int(0.2*n)
dataset_train = dataset_for_loader[:n_train]
dataset_val = dataset_for_loader[n_train:n_train+n_val]
dataset_test = dataset_for_loader[n_train+n_val:]

# Only the training loader shuffles; validation and test keep their order
# and use a larger batch size.
dataloader_train = DataLoader(dataset=dataset_train, batch_size=32, shuffle=True)
dataloader_val = DataLoader(dataset=dataset_val, batch_size=256)
dataloader_test = DataLoader(dataset=dataset_test, batch_size=256)

class BertForSequenceClassification_pl(pl.LightningModule):
    """Lightning module wrapping ``BertForSequenceClassification``.

    Training and validation log the loss returned by the wrapped model;
    testing logs the classification accuracy of each batch.
    """

    def __init__(self, model_name, num_labels, lr):
        """Load the pretrained classifier and record the hyperparameters.

        Args:
            model_name: Identifier passed to ``from_pretrained``.
            num_labels: Number of target classes for the classifier.
            lr: Learning rate used by the Adam optimizer.
        """
        super().__init__()
        # Makes the constructor arguments available through self.hparams
        # (configure_optimizers reads self.hparams.lr).
        self.save_hyperparameters()
        self.bert_sc = BertForSequenceClassification.from_pretrained(
            model_name, num_labels=num_labels
        )

    def training_step(self, batch, batch_idx):
        """Run the model on one training batch, log and return its loss."""
        loss = self.bert_sc(**batch).loss
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Run the model on one validation batch and log its loss."""
        self.log('val_loss', self.bert_sc(**batch).loss)

    def test_step(self, batch, batch_idx):
        """Log the fraction of correctly classified samples in one batch."""
        # Remove the labels from the batch so the model is called without
        # them; predictions are taken from the logits instead.
        labels = batch.pop('labels')
        predictions = self.bert_sc(**batch).logits.argmax(-1)
        hits = (predictions == labels).sum().item()
        self.log('accuracy', hits/labels.size(0))

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using the stored lr."""
        learning_rate = self.hparams.lr
        return torch.optim.Adam(self.parameters(), lr=learning_rate)

# Keep only the single checkpoint with the lowest 'val_loss', saving the
# weights only, under model/.
checkpoint = pl.callbacks.ModelCheckpoint(
    dirpath='model/',
    monitor='val_loss',
    mode='min',
    save_top_k=1,
    save_weights_only=True,
)

# One GPU, at most 10 epochs, with the checkpoint callback attached.
trainer = pl.Trainer(gpus=1, max_epochs=10, callbacks=[checkpoint])

# Nine output labels, matching the nine entries of category_list.
model = BertForSequenceClassification_pl(MODEL_NAME, num_labels=9, lr=1e-5)

### (a) ###

# Fine-tune the model: train on dataloader_train, validate on dataloader_val.
trainer.fit(model, dataloader_train, dataloader_val)

# Evaluate on the test split after fine-tuning and print the 'accuracy'
# value that test_step logs.
test = trainer.test(test_dataloaders=dataloader_test)
print(f'Accuracy: {test[0]["accuracy"]:.2f}')

但是我不太确定如何在微调前做一个测试,以便比较微调前后的两个模型,以显示微调的效果。

将以下两行插入 ### (a) ###:

# Attempt 1: call test() before any fit(). No model is passed here and the
# trainer has none attached yet, so this raises the AttributeError shown below.
test = trainer.test(test_dataloaders=dataloader_test)
print(f'Accuracy: {test[0]["accuracy"]:.2f}')

我得到了这个结果:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-13-c8b2c67f2d5c> in <module>()
      9 
     10 # 6-19
---> 11 test = trainer.test(test_dataloaders=dataloader_test)
     12 print(f'Accuracy: {test[0]["accuracy"]:.2f}')
     13 

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in test(self, model, test_dataloaders, ckpt_path, verbose, datamodule)
    896         self.verbose_test = verbose
    897 
--> 898         self._set_running_stage(RunningStage.TESTING, model or self.lightning_module)
    899 
    900         # If you supply a datamodule you can't supply train_dataloader or val_dataloaders

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in _set_running_stage(self, stage, model_ref)
    563         the trainer and the model
    564         """
--> 565         model_ref.running_stage = stage
    566         self._running_stage = stage
    567 

AttributeError: 'NoneType' object has no attribute 'running_stage'

我注意到 Trainer.fit() 除 model 以外的参数都可以为 None,所以我尝试了这个:

# Attempt 2: call fit() with only the model. The model defines no
# train_dataloader(), so fit() raises the MisconfigurationException shown below.
trainer.fit(model)
test=trainer.test(test_dataloaders=dataloader_test)
print(f'Accuracy: {test[0]["accuracy"]:.2f}')

结果:

MisconfigurationException: No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.

谢谢。

Trainer 需要先调用 .fit() 来完成大量的初始化设置,之后才能调用 .test() 或其他方法。

您把 .fit() 放在 .test() 之前是对的,但 fit 必须以有效的方式调用:您必须向它提供 dataloader/datamodule。不过,由于您不想在这次 fit 调用中执行 training/validation,只需在构造 Trainer 时传入 limit_[train/val]_batches=0:

# Sketch (the "..." stand for your other Trainer arguments).
# limit_train_batches=0 and limit_val_batches=0 make fit() skip training and
# validation, so the model is tested as loaded.
trainer = Trainer(gpus=..., ..., limit_train_batches=0, limit_val_batches=0)
trainer.fit(model, dataloader_train, dataloader_val)
trainer.test(model, dataloader_test) # without fine-tuning

此处的 fit 调用只会完成初始化设置并跳过 training/validation,然后再进行测试。下次运行相同的代码,但去掉 limit_[train/val]_batches,这样就会真正执行微调:

# Sketch (the "..." stand for your other Trainer arguments).
# Without the limit_*_batches arguments, fit() actually trains and validates
# before the model is tested.
trainer = Trainer(gpus=..., ...)
trainer.fit(model, dataloader_train, dataloader_val)
trainer.test(model, dataloader_test) # with fine-tuning

另外澄清一点:“.fit() 除 model 以外的参数都可以为 None”这种说法并不完全正确——您必须提供 DataLoader 或 DataModule。