如何将数据集修复为 return 所需的输出 (pytorch)
How do I fix the Dataset to return desired output (pytorch)
我正在尝试使用来自外部函数的信息来决定将哪些数据 return。在这里,我添加了一个简化的代码来演示这个问题。当我使用 num_workers = 0
时,我得到了想要的行为(3 个纪元后的输出是 18)。但是,当我增加 num_workers
的值时,每个纪元后的输出都是相同的。并且全局变量保持不变。
from torch.utils.data import Dataset, DataLoader
x = 6
def getx():
global x
x+=1
print("x: ", x)
return x
class MyDataset(Dataset):
def __init__(self):
pass
def __getitem__(self, index):
global x
x = getx()
return x
def __len__(self):
return 3
dataset = MyDataset()
loader = DataLoader(
dataset,
num_workers=0,
shuffle=False
)
for epoch in range(4):
for idx, data in enumerate(loader):
print('Epoch {}, idx {}, val: {}'.format(epoch, idx, data))
num_workers=0
时的最终输出符合预期为18。但是当num_workers>0
时,x保持不变(最终输出为6)。
如何使用 num_workers>0
(i.e.How 获得与 num_workers=0
类似的行为,以确保数据加载器的 __getitem__
函数更改全局变量 x
的价值)?
其原因在于 python 中多处理的基本性质。设置 num_workers
意味着您的 DataLoader
创造了那个数量的 sub-processes。每个 sub-process 实际上是一个单独的 python 实例,具有自己的全局状态,并且不知道其他进程中发生了什么。
在 python 的多处理中,一个典型的解决方案是使用 Manager
。但是,由于您的多处理是通过 DataLoader 提供的,因此您无法在其中进行处理。
幸运的是,还可以做点别的。 DataLoader
实际上依赖于 torch.multiprocessing,只要它们在共享内存中,它又允许在进程之间共享张量。
所以你可以做的是,简单地使用 x 作为共享张量。
from torch.utils.data import Dataset, DataLoader
import torch
x = torch.tensor([6])
x.share_memory_()
def getx():
global x
x+=1
print("x: ", x.item())
return x
class MyDataset(Dataset):
def __init__(self):
pass
def __getitem__(self, index):
global x
x = getx()
return x
def __len__(self):
return 3
dataset = MyDataset()
loader = DataLoader(
dataset,
num_workers=2,
shuffle=False
)
for epoch in range(4):
for idx, data in enumerate(loader):
print('Epoch {}, idx {}, val: {}'.format(epoch, idx, data))
输出:
x: 7
x: 8
x: 9
Epoch 0, idx 0, val: tensor([[7]])
Epoch 0, idx 1, val: tensor([[8]])
Epoch 0, idx 2, val: tensor([[9]])
x: 10
x: 11
x: 12
Epoch 1, idx 0, val: tensor([[10]])
Epoch 1, idx 1, val: tensor([[12]])
Epoch 1, idx 2, val: tensor([[12]])
x: 13
x: 14
x: 15
Epoch 2, idx 0, val: tensor([[13]])
Epoch 2, idx 1, val: tensor([[15]])
Epoch 2, idx 2, val: tensor([[14]])
x: 16
x: 17
x: 18
Epoch 3, idx 0, val: tensor([[16]])
Epoch 3, idx 1, val: tensor([[18]])
Epoch 3, idx 2, val: tensor([[17]])
虽然这可行,但并不完美。查看纪元 1,注意有 2 个 12,而不是 11 和 12。这意味着两个独立的进程在执行打印之前执行了行 x+=1
。这是不可避免的,因为并行进程正在共享内存上工作。
如果您熟悉操作系统概念,您可以进一步实现某种 semaphore 并使用额外的变量来根据需要控制对 x 的访问 - 但这超出了范围这个问题我就不多说了。
我正在尝试使用来自外部函数的信息来决定将哪些数据 return。在这里,我添加了一个简化的代码来演示这个问题。当我使用 num_workers = 0
时,我得到了想要的行为(3 个纪元后的输出是 18)。但是,当我增加 num_workers
的值时,每个纪元后的输出都是相同的。并且全局变量保持不变。
from torch.utils.data import Dataset, DataLoader
x = 6
def getx():
global x
x+=1
print("x: ", x)
return x
class MyDataset(Dataset):
def __init__(self):
pass
def __getitem__(self, index):
global x
x = getx()
return x
def __len__(self):
return 3
dataset = MyDataset()
loader = DataLoader(
dataset,
num_workers=0,
shuffle=False
)
for epoch in range(4):
for idx, data in enumerate(loader):
print('Epoch {}, idx {}, val: {}'.format(epoch, idx, data))
num_workers=0
时的最终输出符合预期为18。但是当num_workers>0
时,x保持不变(最终输出为6)。
如何使用 num_workers>0
(i.e.How 获得与 num_workers=0
类似的行为,以确保数据加载器的 __getitem__
函数更改全局变量 x
的价值)?
其原因在于 python 中多处理的基本性质。设置 num_workers
意味着您的 DataLoader
创造了那个数量的 sub-processes。每个 sub-process 实际上是一个单独的 python 实例,具有自己的全局状态,并且不知道其他进程中发生了什么。
在 python 的多处理中,一个典型的解决方案是使用 Manager
。但是,由于您的多处理是通过 DataLoader 提供的,因此您无法在其中进行处理。
幸运的是,还可以做点别的。 DataLoader
实际上依赖于 torch.multiprocessing,只要它们在共享内存中,它又允许在进程之间共享张量。
所以你可以做的是,简单地使用 x 作为共享张量。
from torch.utils.data import Dataset, DataLoader
import torch
x = torch.tensor([6])
x.share_memory_()
def getx():
global x
x+=1
print("x: ", x.item())
return x
class MyDataset(Dataset):
def __init__(self):
pass
def __getitem__(self, index):
global x
x = getx()
return x
def __len__(self):
return 3
dataset = MyDataset()
loader = DataLoader(
dataset,
num_workers=2,
shuffle=False
)
for epoch in range(4):
for idx, data in enumerate(loader):
print('Epoch {}, idx {}, val: {}'.format(epoch, idx, data))
输出:
x: 7
x: 8
x: 9
Epoch 0, idx 0, val: tensor([[7]])
Epoch 0, idx 1, val: tensor([[8]])
Epoch 0, idx 2, val: tensor([[9]])
x: 10
x: 11
x: 12
Epoch 1, idx 0, val: tensor([[10]])
Epoch 1, idx 1, val: tensor([[12]])
Epoch 1, idx 2, val: tensor([[12]])
x: 13
x: 14
x: 15
Epoch 2, idx 0, val: tensor([[13]])
Epoch 2, idx 1, val: tensor([[15]])
Epoch 2, idx 2, val: tensor([[14]])
x: 16
x: 17
x: 18
Epoch 3, idx 0, val: tensor([[16]])
Epoch 3, idx 1, val: tensor([[18]])
Epoch 3, idx 2, val: tensor([[17]])
虽然这可行,但并不完美。查看纪元 1,注意有 2 个 12,而不是 11 和 12。这意味着两个独立的进程在执行打印之前执行了行 x+=1
。这是不可避免的,因为并行进程正在共享内存上工作。
如果您熟悉操作系统概念,您可以进一步实现某种 semaphore 并使用额外的变量来根据需要控制对 x 的访问 - 但这超出了范围这个问题我就不多说了。