type 'NoneType' is not iterable 在使用 ray tunes Trainable API 训练 pytorch 模型时出现错误

type 'NoneType' is not iterable error when training pytorch model with ray tunes Trainable API

我写了一个简单的 pytorch 脚本来训练 MNIST,它运行良好。我重新实现了我的脚本,使其与 Trainable class:

import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import ray

from ray import tune

# Change these values if you want the training to run quicker or slower.
EPOCH_SIZE = 512
TEST_SIZE = 256


class ConvNet(nn.Module):

    def __init__(self):
        super(ConvNet, self).__init__()
        # In this example, we don't change the model architecture
        # due to simplicity.
        self.conv1 = nn.Conv2d(1, 3, kernel_size=3)
        self.fc = nn.Linear(192, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 3))
        x = x.view(-1, 192)
        x = self.fc(x)
        return F.log_softmax(x, dim=1)


class AlexTrainer(tune.Trainable):

    def setup(self, config):

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Data Setup
        mnist_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

        self.train_loader = DataLoader(
            datasets.MNIST("~/data", train=True, download=True, transform=mnist_transforms),
            batch_size=64,
            shuffle=True)
        self.test_loader = DataLoader(
            datasets.MNIST("~/data", train=False, transform=mnist_transforms),
            batch_size=64,
            shuffle=True)

        self.model = ConvNet()
        self.optimizer = optim.SGD(self.model.parameters(), lr=config["lr"], momentum=config["momentum"])

        print('finished setup')

    def step(self):

        self.train()
        print("after train")
        acc = self.test()

        return {'acc': acc}

    def train(self):

        print("in train")

        self.model.train()
        for batch_idx, (data, target) in enumerate(self.train_loader):

            # We set this just for the example to run quickly.
            if batch_idx * len(data) > EPOCH_SIZE:
                return

            data, target = data.to(self.device), target.to(self.device)
            self.optimizer.zero_grad()

            print(type(data))

            output = self.model(data)
            loss = F.nll_loss(output, target)
            loss.backward()

            self.optimizer.step()

    def test(self):
        self.model.eval()

        correct = 0
        total = 0
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(self.test_loader):
                # We set this just for the example to run quickly.
                if batch_idx * len(data) > TEST_SIZE:
                    break
                data, target = data.to(self.device), target.to(self.device)
                outputs = self.model(data)
                _, predicted = torch.max(outputs.data, 1)
                total += target.size(0)
                correct += (predicted == target).sum().item()

        return correct / total


if __name__ == '__main__':
    ray.init()
    analysis = tune.run(
        AlexTrainer,
        stop={"training_iteration": 2},
        # verbose=1,
        config={
            "lr": tune.sample_from(lambda spec: 10 ** (-10 * np.random.rand())),
            "momentum": tune.uniform(0.1, 0.9)
        }
    )

我怎么尝试 运行,这次失败了:

Traceback (most recent call last):
  File "/hdd/raytune/venv/lib/python3.6/site-packages/ray/tune/trial_runner.py", line 473, in _process_trial
    is_duplicate = RESULT_DUPLICATE in result
TypeError: argument of type 'NoneType' is not iterable
Traceback (most recent call last):
  File "/hdd/raytune/test_3.py", line 116, in <module>
    "momentum": tune.uniform(0.1, 0.9)
  File "/hdd/raytune/venv/lib/python3.6/site-packages/ray/tune/tune.py", line 356, in run
    raise TuneError("Trials did not complete", incomplete_trials)
ray.tune.error.TuneError: ('Trials did not complete', [AlexTrainer_9b3cd_00000])

这可能是什么原因?

这是因为您实际上覆盖了 Trainable 中现有的 train 方法。如果您将 train 方法重命名为其他名称,它应该会按预期工作。