自定义采样器在 Pytorch 中的正确使用
Custom Sampler correct use in Pytorch
我有一个映射式(map-style)数据集,用于实例分割任务。
数据集非常不平衡,有些图像只有 10 个对象,而其他图像有多达 1200 个对象。
如何限制每批对象的数量?
一个最小的可重现示例是:
import math
import torch
import random
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
from torch.utils.data.sampler import BatchSampler
# Seed every random number generator used in this example so that repeated
# runs draw the same values.
np.random.seed(0)
random.seed(0)
torch.manual_seed(0)
# Width and height of the synthetic image and mask tensors.
W = 700
H = 1000
def collate_fn(batch) -> tuple:
    """Transpose a batch of samples into one tuple per field.

    ``batch`` is a sequence of samples such as
    ``[(image, target, image_id), ...]``; the result groups the i-th field
    of every sample together, e.g. ``(images, targets, image_ids)``.
    """
    fields = zip(*batch)
    return tuple(fields)
class SyntheticDataset(Dataset):
    """Map-style dataset that fabricates random instance-segmentation samples.

    Nothing is loaded from disk: every ``__getitem__`` call draws a fresh
    random image together with a random number of object annotations, so
    the size of a sample varies widely from call to call.

    Args:
        image_ids: sequence of integer image ids. Only its length is used
            when samples are generated.
    """

    def __init__(self, image_ids):
        self.image_ids = torch.tensor(image_ids, dtype=torch.int64)
        self.num_classes = 9

    def __len__(self):
        """Return the number of image ids the dataset was built from."""
        return len(self.image_ids)

    def __getitem__(self, idx: int):
        """Generate one random sample.

        Returns:
            ``(image, target, image_id)`` where ``image`` has shape
            ``(3, H, W)``, ``image_id`` is ``idx`` as a tensor, and
            ``target`` describes ``num_objects`` objects: ``boxes`` has
            shape ``(num_objects, 4)``, ``masks`` has shape
            ``(num_objects, H, W)`` and ``labels``, ``area`` and
            ``iscrowd`` each have shape ``(num_objects,)``.
        """
        # The positional index doubles as the image id; self.image_ids is
        # deliberately not consulted here.
        image_id = torch.as_tensor(idx)
        num_objects = random.randint(10, 1200)
        image = torch.randint(0, 255, (3, H, W))
        masks = torch.randint(0, 255, (num_objects, H, W))
        # Per-object annotations are 1-D with one entry per object so that
        # labels, area and iscrowd line up with masks and boxes. (Sizing
        # iscrowd with len() of a (1, num_objects) tensor gave length 1.)
        areas = torch.randint(100, 20000, (num_objects,), dtype=torch.int64)
        # Box values are arbitrary placeholders, not valid image coordinates.
        boxes = torch.randint(100, H * W, (num_objects, 4), dtype=torch.int64)
        labels = torch.randint(
            1, self.num_classes, (num_objects,), dtype=torch.int64
        )
        iscrowd = torch.zeros(num_objects, dtype=torch.int64)
        target = {
            "image_id": image_id,
            "boxes": boxes,
            "labels": labels,
            "area": areas,
            "iscrowd": iscrowd,
            "masks": masks,
        }
        return image, target, image_id
class BalancedObjectsSampler(BatchSampler):
    """Batch sampler that caps both images and objects per batch.

    Images are visited in ``data_source`` order and packed greedily: a batch
    is closed as soon as adding the next image would exceed ``batch_size``
    images or ``num_objs_per_batch`` objects. Every image ends up in exactly
    one batch; an image whose own object count exceeds the budget is emitted
    alone, since a single image cannot be split.

    Args:
        data_source: mapping-like object exposing ``items()`` (for example a
            ``pandas.Series`` or a ``dict``) that maps an image id to the
            number of objects in that image.
        batch_size (int): maximum number of images in a batch.
        num_objs_per_batch (int): maximum total number of objects in a batch.
        drop_last (bool): if True, discard the final batch when it holds
            fewer than ``batch_size`` images.

    Yields:
        Lists of positional indices into ``data_source``.
    """

    def __init__(self, data_source, batch_size, num_objs_per_batch, drop_last=False):
        self.data_source = data_source
        self.sampler = data_source
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.num_objs_per_batch = num_objs_per_batch
        # Batch count if only the image cap applied, i.e. a lower bound;
        # use len(self) for the actual number of batches.
        self.batch_count = math.ceil(len(self.data_source) / self.batch_size)

    def __iter__(self):
        batch = []
        obj_count = 0
        # ``items()`` is used because ``Series.iteritems`` no longer exists
        # in pandas 2.x; positional indices are what a DataLoader expects.
        for i, (_, s) in enumerate(self.data_source.items()):
            batch_is_full = len(batch) >= self.batch_size
            over_budget = obj_count + s > self.num_objs_per_batch
            if batch and (batch_is_full or over_budget):
                yield batch
                batch = []
                obj_count = 0
            # The image that closed the previous batch starts the next one
            # instead of being dropped.
            batch.append(i)
            obj_count += s
        if batch and not (self.drop_last and len(batch) < self.batch_size):
            yield batch

    def __len__(self) -> int:
        """Return the number of batches one pass over the sampler yields."""
        return sum(1 for _ in self)
# Loader settings shared by both DataLoader instances below.
workers = 4
batch_size = 10
fake_image_ids = np.random.randint(1600000, 1700000, 100)
# Pair every fake image id with an arbitrary object count in [10, 1200].
obj_sums = {image_id: random.randint(10, 1200) for image_id in fake_image_ids}
obj_counts = pd.Series(obj_sums)
train_dataset = SyntheticDataset(image_ids=fake_image_ids)
balanced_sampler = BalancedObjectsSampler(
    obj_counts,
    batch_size,
    num_objs_per_batch=1500,
    drop_last=False,
)
# Loader driven by the object-aware sampler.
# NOTE(review): balanced_sampler yields lists of indices. DataLoader's
# ``sampler=`` argument treats each yielded value as ONE sample index, while
# ``batch_sampler=`` is the argument for samplers that yield whole batches --
# confirm which one is intended here.
data_loader_sampler = torch.utils.data.DataLoader(
    dataset=train_dataset,
    sampler=balanced_sampler,
    collate_fn=collate_fn,
    num_workers=workers,
)
# Baseline loader: fixed-size, unshuffled batches.
data_loader_iter = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=workers,
)
遍历 balanced_sampler
# Print every index batch the sampler produces, one line per batch.
for batch_no, index_batch in enumerate(balanced_sampler):
    print(f"batch_{batch_no}: ", index_batch)
输出:
batch_0: [0]
batch_1: [2, 3]
batch_2: [5]
batch_3: [7]
batch_4: [9, 10]
batch_5: [12, 13, 14, 15]
batch_6: [17, 18]
batch_7: [20, 21, 22]
batch_8: [24, 25]
batch_9: [27]
batch_10: [29]
batch_11: [31]
batch_12: [33]
batch_13: [35, 36, 37]
batch_14: [39, 40]
batch_15: [42, 43]
batch_16: [45, 46]
batch_17: [48, 49, 50]
batch_18: [52, 53, 54]
batch_19: [56]
batch_20: [58, 59]
batch_21: [61, 62]
batch_22: [64]
batch_23: [66]
batch_24: [68]
batch_25: [70, 71]
batch_26: [73]
batch_27: [75, 76, 77]
batch_28: [79, 80]
batch_29: [82, 83, 84, 85, 86, 87]
batch_30: [89]
batch_31: [91]
batch_32: [93, 94]
batch_33: [96]
batch_34: [98]
以上显示的值是图片的索引,但也可以是批次索引甚至是图片的id。
运行以下代码:
# Report how many samples each batch from the sampler-driven loader contains.
for step, loaded in enumerate(data_loader_sampler):
    num_samples = len(loaded[0])
    print("__sample__: ", step, num_samples)
看到该批次包含单个样本,而不是预期的数量。
__sample__: 0 1
__sample__: 1 1
__sample__: 2 1
__sample__: 3 1
__sample__: 4 1
__sample__: 5 1
__sample__: 6 1
__sample__: 7 1
__sample__: 8 1
__sample__: 9 1
__sample__: 10 1
__sample__: 11 1
__sample__: 12 1
__sample__: 13 1
__sample__: 14 1
__sample__: 15 1
__sample__: 16 1
__sample__: 17 1
__sample__: 18 1
__sample__: 19 1
__sample__: 20 1
__sample__: 21 1
__sample__: 22 1
__sample__: 23 1
__sample__: 24 1
__sample__: 25 1
__sample__: 26 1
__sample__: 27 1
__sample__: 28 1
__sample__: 29 1
__sample__: 30 1
__sample__: 31 1
__sample__: 32 1
__sample__: 33 1
__sample__: 34 1
我真正想要防止的是由以下代码引起的行为:
# Report the total number of object masks in each fixed-size batch.
for step, loaded in enumerate(data_loader_iter):
    targets = loaded[1]
    total_objects = sum(t["masks"].shape[0] for t in targets)
    print("__iter__: ", step, total_objects)
也就是
__iter__: 0 2510
__iter__: 1 2060
__iter__: 2 2203
__iter__: 3 2815
ERROR: Unexpected bus error encountered in worker. This might be caused by insufficient shared memory (shm).
Traceback (most recent call last):
File "/usr/lib/python3.8/multiprocessing/queues.py", line 239, in _feed
obj = _ForkingPickler.dumps(obj)
File "/usr/lib/python3.8/multiprocessing/reduction.py", line 51, in dumps
cls(buf, protocol).dump(obj)
File "/blip/venv/lib/python3.8/site-packages/torch/multiprocessing/reductions.py", line 328, in reduce_storage
fd, size = storage._share_fd_()
RuntimeError: falseINTERNAL ASSERT FAILED at "../aten/src/ATen/MapAllocator.cpp":300, please report a bug to PyTorch. unable to write to file </torch_431207_56>
Traceback (most recent call last):
File "/blip/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 990, in _try_get_data
data = self._data_queue.get(timeout=timeout)
File "/usr/lib/python3.8/multiprocessing/queues.py", line 107, in get
if not self._poll(timeout):
File "/usr/lib/python3.8/multiprocessing/connection.py", line 257, in poll
return self._poll(timeout)
File "/usr/lib/python3.8/multiprocessing/connection.py", line 424, in _poll
r = wait([self], timeout)
File "/usr/lib/python3.8/multiprocessing/connection.py", line 931, in wait
ready = selector.select(timeout)
File "/usr/lib/python3.8/selectors.py", line 415, in select
fd_event_list = self._selector.poll(timeout)
File "/blip/venv/lib/python3.8/site-packages/torch/utils/data/_utils/signal_handling.py", line 66, in handler
_error_if_any_worker_fails()
RuntimeError: DataLoader worker (pid 431257) is killed by signal: Bus error. It is possible that dataloader's workers are out of shared memory. Please try to raise your shared memory limit.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "so.py", line 170, in <module>
for i, batch in enumerate(data_loader_iter):
File "/blip/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 521, in __next__
data = self._next_data()
File "/blip/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1186, in _next_data
idx, data = self._get_data()
File "/blip/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1152, in _get_data
success, data = self._try_get_data()
File "/blip/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1003, in _try_get_data
raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str)) from e
RuntimeError: DataLoader worker (pid(s) 431257) exited unexpectedly
当每批次的对象数量大于 ~2500 时,这种情况总是会发生。
一个直接的解决方法是将 batch_size
设置得较低,我只需要一个更优化的解决方案。
如果您真正要解决的问题是:
ERROR: Unexpected bus error encountered in worker. This might be caused by insufficient shared memory (shm).
您可以尝试使用以下命令调整分配的共享内存的大小:
# mount -o remount,size=<whatever_is_enough>G /dev/shm
但是,这并不总是可行的,解决您的问题的一种方法是
class SyntheticDataset(Dataset):
    """Synthetic instance-segmentation dataset indexed by a list of values.

    ``__getitem__`` receives a whole list (one entry per sample of the batch)
    and returns the list of generated samples, so batching happens inside
    the dataset rather than in the DataLoader.

    Args:
        image_ids: sequence of integer image ids. Only its length is used
            when samples are generated.
    """

    def __init__(self, image_ids):
        self.image_ids = torch.tensor(image_ids, dtype=torch.int64)
        self.num_classes = 9

    def __len__(self):
        """Return the number of image ids the dataset was built from."""
        return len(self.image_ids)

    def __getitem__(self, indices):
        """Return ``[self.get_sample(i) for i in indices]``.

        Args:
            indices: iterable of values, each forwarded to ``get_sample``.
        """
        # Imported here because the module never imports ``gc`` itself.
        import gc

        batch = [self.get_sample(i) for i in indices]
        # Collect garbage once the batch is assembled, before handing it on.
        gc.collect()
        return batch

    def get_sample(self, idx: int):
        """Generate one random sample containing ``idx`` objects.

        ``idx`` is used both as the image id and as the number of objects.
        NOTE(review): presumably this relies on the companion sampler
        yielding per-image object counts -- confirm against the caller.

        Returns:
            ``(image, target, image_id)`` where ``image`` has shape
            ``(3, H, W)`` and ``target`` holds ``boxes`` of shape
            ``(idx, 4)``, ``masks`` of shape ``(idx, H, W)`` and ``labels``,
            ``area`` and ``iscrowd`` of shape ``(idx,)``.
        """
        image_id = torch.as_tensor(idx)
        num_objects = idx
        image = torch.randint(0, 255, (3, H, W))
        masks = torch.randint(0, 255, (num_objects, H, W))
        # Per-object annotations are 1-D with one entry per object so that
        # labels, area and iscrowd line up with masks and boxes. (Sizing
        # iscrowd with len() of a (1, num_objects) tensor gave length 1.)
        areas = torch.randint(100, 20000, (num_objects,), dtype=torch.int64)
        # Box values are arbitrary placeholders, not valid image coordinates.
        boxes = torch.randint(100, H * W, (num_objects, 4), dtype=torch.int64)
        labels = torch.randint(
            1, self.num_classes, (num_objects,), dtype=torch.int64
        )
        iscrowd = torch.zeros(num_objects, dtype=torch.int64)
        target = {
            "image_id": image_id,
            "boxes": boxes,
            "labels": labels,
            "area": areas,
            "iscrowd": iscrowd,
            "masks": masks,
        }
        return image, target, image_id
和
class BalancedObjectsSampler(BatchSampler):
    """Batch sampler that caps both images and objects per batch.

    Batch sizes are computed once in ``__init__`` by greedy packing in
    ``data_source`` order: a batch is closed as soon as adding the next image
    would exceed ``batch_size`` images or bring the object total to
    ``num_objs_per_batch`` or more. Every image lands in exactly one batch;
    an image that alone reaches the budget is emitted by itself.

    Args:
        data_source: mapping-like object exposing ``items()`` (for example a
            ``pandas.Series`` or a ``dict``) that maps an image id to the
            number of objects in that image.
        batch_size (int): maximum number of images in a batch.
        num_objs_per_batch (int): exclusive upper bound on the total number
            of objects in a batch.
        drop_last (bool): if True, discard the final batch when it holds
            fewer than ``batch_size`` images.

    Yields:
        Lists holding the per-image object counts taken from
        ``data_source``, one list per batch.
        NOTE(review): these are counts, not positional indices; a dataset
        that must be indexed by position needs indices instead -- confirm
        against the dataset this sampler is paired with.
    """

    def __init__(self, data_source, batch_size, num_objs_per_batch, drop_last=False):
        self.data_source = data_source
        self.sampler = data_source
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.num_objs_per_batch = num_objs_per_batch

        sizes = []
        images_in_batch = 0
        obj_count = 0
        # ``items()`` is used because ``Series.iteritems`` no longer exists
        # in pandas 2.x.
        for _, s in self.data_source.items():
            batch_is_full = images_in_batch >= self.batch_size
            over_budget = obj_count + s >= self.num_objs_per_batch
            if images_in_batch and (batch_is_full or over_budget):
                sizes.append(images_in_batch)
                images_in_batch = 0
                obj_count = 0
            # The image that closed the previous batch opens the next one
            # instead of being dropped.
            images_in_batch += 1
            obj_count += s
        if images_in_batch:
            sizes.append(images_in_batch)

        # Number of images in each batch, in iteration order.
        self.batches = sizes
        self.batch_count = len(sizes)

    def _batch_sizes(self):
        """Return the batch sizes to emit, honouring ``drop_last``."""
        sizes = self.batches
        if self.drop_last and sizes and sizes[-1] < self.batch_size:
            return sizes[:-1]
        return sizes

    def __iter__(self):
        # Imported here because the module never imports ``gc`` itself.
        import gc

        values = [s for _, s in self.data_source.items()]
        start = 0
        for size in self._batch_sizes():
            gc.collect()
            yield values[start:start + size]
            start += size

    def __len__(self) -> int:
        """Return the number of batches one pass over the sampler yields."""
        return len(self._batch_sizes())
由于 SyntheticDataset 的 __getitem__
正在接收一个索引列表,最简单的解决方案就是遍历索引并检索样本列表。您可能只需要以不同的方式整理输出,以便将其提供给您的模型。
对于 BalancedObjectsSampler,我在 __init__ 中计算了每个批次的大小,并在 __iter__ 中用它来组装各个批次。
注意:如果您的 num_workers > 0,试图将最多 1500 个对象打包到一个批次中仍然会失败——通常每个工作进程一次加载一个批次。因此,在考虑使用多进程时,您必须重新评估 num_objs_per_batch 的取值。
我有一个映射式(map-style)数据集,用于实例分割任务。数据集非常不平衡,有些图像只有 10 个对象,而其他图像有多达 1200 个对象。
如何限制每批对象的数量?
一个最小的可重现示例是:
import math
import torch
import random
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
from torch.utils.data.sampler import BatchSampler
# Seed every random number generator used in this example so that repeated
# runs draw the same values.
np.random.seed(0)
random.seed(0)
torch.manual_seed(0)
# Width and height of the synthetic image and mask tensors.
W = 700
H = 1000
def collate_fn(batch) -> tuple:
    """Transpose a batch of samples into one tuple per field.

    ``batch`` is a sequence of samples such as
    ``[(image, target, image_id), ...]``; the result groups the i-th field
    of every sample together, e.g. ``(images, targets, image_ids)``.
    """
    fields = zip(*batch)
    return tuple(fields)
class SyntheticDataset(Dataset):
    """Map-style dataset that fabricates random instance-segmentation samples.

    Nothing is loaded from disk: every ``__getitem__`` call draws a fresh
    random image together with a random number of object annotations, so
    the size of a sample varies widely from call to call.

    Args:
        image_ids: sequence of integer image ids. Only its length is used
            when samples are generated.
    """

    def __init__(self, image_ids):
        self.image_ids = torch.tensor(image_ids, dtype=torch.int64)
        self.num_classes = 9

    def __len__(self):
        """Return the number of image ids the dataset was built from."""
        return len(self.image_ids)

    def __getitem__(self, idx: int):
        """Generate one random sample.

        Returns:
            ``(image, target, image_id)`` where ``image`` has shape
            ``(3, H, W)``, ``image_id`` is ``idx`` as a tensor, and
            ``target`` describes ``num_objects`` objects: ``boxes`` has
            shape ``(num_objects, 4)``, ``masks`` has shape
            ``(num_objects, H, W)`` and ``labels``, ``area`` and
            ``iscrowd`` each have shape ``(num_objects,)``.
        """
        # The positional index doubles as the image id; self.image_ids is
        # deliberately not consulted here.
        image_id = torch.as_tensor(idx)
        num_objects = random.randint(10, 1200)
        image = torch.randint(0, 255, (3, H, W))
        masks = torch.randint(0, 255, (num_objects, H, W))
        # Per-object annotations are 1-D with one entry per object so that
        # labels, area and iscrowd line up with masks and boxes. (Sizing
        # iscrowd with len() of a (1, num_objects) tensor gave length 1.)
        areas = torch.randint(100, 20000, (num_objects,), dtype=torch.int64)
        # Box values are arbitrary placeholders, not valid image coordinates.
        boxes = torch.randint(100, H * W, (num_objects, 4), dtype=torch.int64)
        labels = torch.randint(
            1, self.num_classes, (num_objects,), dtype=torch.int64
        )
        iscrowd = torch.zeros(num_objects, dtype=torch.int64)
        target = {
            "image_id": image_id,
            "boxes": boxes,
            "labels": labels,
            "area": areas,
            "iscrowd": iscrowd,
            "masks": masks,
        }
        return image, target, image_id
class BalancedObjectsSampler(BatchSampler):
    """Batch sampler that caps both images and objects per batch.

    Images are visited in ``data_source`` order and packed greedily: a batch
    is closed as soon as adding the next image would exceed ``batch_size``
    images or ``num_objs_per_batch`` objects. Every image ends up in exactly
    one batch; an image whose own object count exceeds the budget is emitted
    alone, since a single image cannot be split.

    Args:
        data_source: mapping-like object exposing ``items()`` (for example a
            ``pandas.Series`` or a ``dict``) that maps an image id to the
            number of objects in that image.
        batch_size (int): maximum number of images in a batch.
        num_objs_per_batch (int): maximum total number of objects in a batch.
        drop_last (bool): if True, discard the final batch when it holds
            fewer than ``batch_size`` images.

    Yields:
        Lists of positional indices into ``data_source``.
    """

    def __init__(self, data_source, batch_size, num_objs_per_batch, drop_last=False):
        self.data_source = data_source
        self.sampler = data_source
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.num_objs_per_batch = num_objs_per_batch
        # Batch count if only the image cap applied, i.e. a lower bound;
        # use len(self) for the actual number of batches.
        self.batch_count = math.ceil(len(self.data_source) / self.batch_size)

    def __iter__(self):
        batch = []
        obj_count = 0
        # ``items()`` is used because ``Series.iteritems`` no longer exists
        # in pandas 2.x; positional indices are what a DataLoader expects.
        for i, (_, s) in enumerate(self.data_source.items()):
            batch_is_full = len(batch) >= self.batch_size
            over_budget = obj_count + s > self.num_objs_per_batch
            if batch and (batch_is_full or over_budget):
                yield batch
                batch = []
                obj_count = 0
            # The image that closed the previous batch starts the next one
            # instead of being dropped.
            batch.append(i)
            obj_count += s
        if batch and not (self.drop_last and len(batch) < self.batch_size):
            yield batch

    def __len__(self) -> int:
        """Return the number of batches one pass over the sampler yields."""
        return sum(1 for _ in self)
# Loader settings shared by both DataLoader instances below.
workers = 4
batch_size = 10
fake_image_ids = np.random.randint(1600000, 1700000, 100)
# Pair every fake image id with an arbitrary object count in [10, 1200].
obj_sums = {image_id: random.randint(10, 1200) for image_id in fake_image_ids}
obj_counts = pd.Series(obj_sums)
train_dataset = SyntheticDataset(image_ids=fake_image_ids)
balanced_sampler = BalancedObjectsSampler(
    obj_counts,
    batch_size,
    num_objs_per_batch=1500,
    drop_last=False,
)
# Loader driven by the object-aware sampler.
# NOTE(review): balanced_sampler yields lists of indices. DataLoader's
# ``sampler=`` argument treats each yielded value as ONE sample index, while
# ``batch_sampler=`` is the argument for samplers that yield whole batches --
# confirm which one is intended here.
data_loader_sampler = torch.utils.data.DataLoader(
    dataset=train_dataset,
    sampler=balanced_sampler,
    collate_fn=collate_fn,
    num_workers=workers,
)
# Baseline loader: fixed-size, unshuffled batches.
data_loader_iter = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=workers,
)
遍历 balanced_sampler
# Print every index batch the sampler produces, one line per batch.
for batch_no, index_batch in enumerate(balanced_sampler):
    print(f"batch_{batch_no}: ", index_batch)
输出:
batch_0: [0]
batch_1: [2, 3]
batch_2: [5]
batch_3: [7]
batch_4: [9, 10]
batch_5: [12, 13, 14, 15]
batch_6: [17, 18]
batch_7: [20, 21, 22]
batch_8: [24, 25]
batch_9: [27]
batch_10: [29]
batch_11: [31]
batch_12: [33]
batch_13: [35, 36, 37]
batch_14: [39, 40]
batch_15: [42, 43]
batch_16: [45, 46]
batch_17: [48, 49, 50]
batch_18: [52, 53, 54]
batch_19: [56]
batch_20: [58, 59]
batch_21: [61, 62]
batch_22: [64]
batch_23: [66]
batch_24: [68]
batch_25: [70, 71]
batch_26: [73]
batch_27: [75, 76, 77]
batch_28: [79, 80]
batch_29: [82, 83, 84, 85, 86, 87]
batch_30: [89]
batch_31: [91]
batch_32: [93, 94]
batch_33: [96]
batch_34: [98]
以上显示的值是图片的索引,但也可以是批次索引甚至是图片的id。
运行以下代码:
# Report how many samples each batch from the sampler-driven loader contains.
for step, loaded in enumerate(data_loader_sampler):
    num_samples = len(loaded[0])
    print("__sample__: ", step, num_samples)
看到该批次包含单个样本,而不是预期的数量。
__sample__: 0 1
__sample__: 1 1
__sample__: 2 1
__sample__: 3 1
__sample__: 4 1
__sample__: 5 1
__sample__: 6 1
__sample__: 7 1
__sample__: 8 1
__sample__: 9 1
__sample__: 10 1
__sample__: 11 1
__sample__: 12 1
__sample__: 13 1
__sample__: 14 1
__sample__: 15 1
__sample__: 16 1
__sample__: 17 1
__sample__: 18 1
__sample__: 19 1
__sample__: 20 1
__sample__: 21 1
__sample__: 22 1
__sample__: 23 1
__sample__: 24 1
__sample__: 25 1
__sample__: 26 1
__sample__: 27 1
__sample__: 28 1
__sample__: 29 1
__sample__: 30 1
__sample__: 31 1
__sample__: 32 1
__sample__: 33 1
__sample__: 34 1
我真正想要防止的是由以下代码引起的行为:
for i, batch in enumerate(data_loader_iter):
    print("__iter__: ", i, sum([k["masks"].shape[0] for k in batch[1]]))
也就是
__iter__: 0 2510
__iter__: 1 2060
__iter__: 2 2203
__iter__: 3 2815
ERROR: Unexpected bus error encountered in worker. This might be caused by insufficient shared memory (shm).
Traceback (most recent call last):
File "/usr/lib/python3.8/multiprocessing/queues.py", line 239, in _feed
obj = _ForkingPickler.dumps(obj)
File "/usr/lib/python3.8/multiprocessing/reduction.py", line 51, in dumps
cls(buf, protocol).dump(obj)
File "/blip/venv/lib/python3.8/site-packages/torch/multiprocessing/reductions.py", line 328, in reduce_storage
fd, size = storage._share_fd_()
RuntimeError: falseINTERNAL ASSERT FAILED at "../aten/src/ATen/MapAllocator.cpp":300, please report a bug to PyTorch. unable to write to file </torch_431207_56>
Traceback (most recent call last):
File "/blip/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 990, in _try_get_data
data = self._data_queue.get(timeout=timeout)
File "/usr/lib/python3.8/multiprocessing/queues.py", line 107, in get
if not self._poll(timeout):
File "/usr/lib/python3.8/multiprocessing/connection.py", line 257, in poll
return self._poll(timeout)
File "/usr/lib/python3.8/multiprocessing/connection.py", line 424, in _poll
r = wait([self], timeout)
File "/usr/lib/python3.8/multiprocessing/connection.py", line 931, in wait
ready = selector.select(timeout)
File "/usr/lib/python3.8/selectors.py", line 415, in select
fd_event_list = self._selector.poll(timeout)
File "/blip/venv/lib/python3.8/site-packages/torch/utils/data/_utils/signal_handling.py", line 66, in handler
_error_if_any_worker_fails()
RuntimeError: DataLoader worker (pid 431257) is killed by signal: Bus error. It is possible that dataloader's workers are out of shared memory. Please try to raise your shared memory limit.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "so.py", line 170, in <module>
for i, batch in enumerate(data_loader_iter):
File "/blip/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 521, in __next__
data = self._next_data()
File "/blip/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1186, in _next_data
idx, data = self._get_data()
File "/blip/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1152, in _get_data
success, data = self._try_get_data()
File "/blip/venv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1003, in _try_get_data
raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str)) from e
RuntimeError: DataLoader worker (pid(s) 431257) exited unexpectedly
当每批次的对象数量大于 ~2500 时,这种情况总是会发生。
一个直接的解决方法是将 batch_size
设置得较低,我只需要一个更优化的解决方案。
如果您真正要解决的问题是:
ERROR: Unexpected bus error encountered in worker. This might be caused by insufficient shared memory (shm).
您可以尝试使用以下命令调整分配的共享内存的大小:
# mount -o remount,size=<whatever_is_enough>G /dev/shm
但是,这并不总是可行的,解决您的问题的一种方法是
class SyntheticDataset(Dataset):
    """Synthetic instance-segmentation dataset indexed by a list of values.

    ``__getitem__`` receives a whole list (one entry per sample of the batch)
    and returns the list of generated samples, so batching happens inside
    the dataset rather than in the DataLoader.

    Args:
        image_ids: sequence of integer image ids. Only its length is used
            when samples are generated.
    """

    def __init__(self, image_ids):
        self.image_ids = torch.tensor(image_ids, dtype=torch.int64)
        self.num_classes = 9

    def __len__(self):
        """Return the number of image ids the dataset was built from."""
        return len(self.image_ids)

    def __getitem__(self, indices):
        """Return ``[self.get_sample(i) for i in indices]``.

        Args:
            indices: iterable of values, each forwarded to ``get_sample``.
        """
        # Imported here because the module never imports ``gc`` itself.
        import gc

        batch = [self.get_sample(i) for i in indices]
        # Collect garbage once the batch is assembled, before handing it on.
        gc.collect()
        return batch

    def get_sample(self, idx: int):
        """Generate one random sample containing ``idx`` objects.

        ``idx`` is used both as the image id and as the number of objects.
        NOTE(review): presumably this relies on the companion sampler
        yielding per-image object counts -- confirm against the caller.

        Returns:
            ``(image, target, image_id)`` where ``image`` has shape
            ``(3, H, W)`` and ``target`` holds ``boxes`` of shape
            ``(idx, 4)``, ``masks`` of shape ``(idx, H, W)`` and ``labels``,
            ``area`` and ``iscrowd`` of shape ``(idx,)``.
        """
        image_id = torch.as_tensor(idx)
        num_objects = idx
        image = torch.randint(0, 255, (3, H, W))
        masks = torch.randint(0, 255, (num_objects, H, W))
        # Per-object annotations are 1-D with one entry per object so that
        # labels, area and iscrowd line up with masks and boxes. (Sizing
        # iscrowd with len() of a (1, num_objects) tensor gave length 1.)
        areas = torch.randint(100, 20000, (num_objects,), dtype=torch.int64)
        # Box values are arbitrary placeholders, not valid image coordinates.
        boxes = torch.randint(100, H * W, (num_objects, 4), dtype=torch.int64)
        labels = torch.randint(
            1, self.num_classes, (num_objects,), dtype=torch.int64
        )
        iscrowd = torch.zeros(num_objects, dtype=torch.int64)
        target = {
            "image_id": image_id,
            "boxes": boxes,
            "labels": labels,
            "area": areas,
            "iscrowd": iscrowd,
            "masks": masks,
        }
        return image, target, image_id
和
class BalancedObjectsSampler(BatchSampler):
    """Batch sampler that caps both images and objects per batch.

    Batch sizes are computed once in ``__init__`` by greedy packing in
    ``data_source`` order: a batch is closed as soon as adding the next image
    would exceed ``batch_size`` images or bring the object total to
    ``num_objs_per_batch`` or more. Every image lands in exactly one batch;
    an image that alone reaches the budget is emitted by itself.

    Args:
        data_source: mapping-like object exposing ``items()`` (for example a
            ``pandas.Series`` or a ``dict``) that maps an image id to the
            number of objects in that image.
        batch_size (int): maximum number of images in a batch.
        num_objs_per_batch (int): exclusive upper bound on the total number
            of objects in a batch.
        drop_last (bool): if True, discard the final batch when it holds
            fewer than ``batch_size`` images.

    Yields:
        Lists holding the per-image object counts taken from
        ``data_source``, one list per batch.
        NOTE(review): these are counts, not positional indices; a dataset
        that must be indexed by position needs indices instead -- confirm
        against the dataset this sampler is paired with.
    """

    def __init__(self, data_source, batch_size, num_objs_per_batch, drop_last=False):
        self.data_source = data_source
        self.sampler = data_source
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.num_objs_per_batch = num_objs_per_batch

        sizes = []
        images_in_batch = 0
        obj_count = 0
        # ``items()`` is used because ``Series.iteritems`` no longer exists
        # in pandas 2.x.
        for _, s in self.data_source.items():
            batch_is_full = images_in_batch >= self.batch_size
            over_budget = obj_count + s >= self.num_objs_per_batch
            if images_in_batch and (batch_is_full or over_budget):
                sizes.append(images_in_batch)
                images_in_batch = 0
                obj_count = 0
            # The image that closed the previous batch opens the next one
            # instead of being dropped.
            images_in_batch += 1
            obj_count += s
        if images_in_batch:
            sizes.append(images_in_batch)

        # Number of images in each batch, in iteration order.
        self.batches = sizes
        self.batch_count = len(sizes)

    def _batch_sizes(self):
        """Return the batch sizes to emit, honouring ``drop_last``."""
        sizes = self.batches
        if self.drop_last and sizes and sizes[-1] < self.batch_size:
            return sizes[:-1]
        return sizes

    def __iter__(self):
        # Imported here because the module never imports ``gc`` itself.
        import gc

        values = [s for _, s in self.data_source.items()]
        start = 0
        for size in self._batch_sizes():
            gc.collect()
            yield values[start:start + size]
            start += size

    def __len__(self) -> int:
        """Return the number of batches one pass over the sampler yields."""
        return len(self._batch_sizes())
由于 SyntheticDataset 的 __getitem__
正在接收一个索引列表,最简单的解决方案就是遍历索引并检索样本列表。您可能只需要以不同的方式整理输出,以便将其提供给您的模型。
对于 BalancedObjectsSampler,我在 __init__ 中计算了每个批次的大小,并在 __iter__ 中用它来组装各个批次。
注意:如果您的 num_workers > 0,试图将最多 1500 个对象打包到一个批次中仍然会失败——通常每个工作进程一次加载一个批次。因此,在考虑使用多进程时,您必须重新评估 num_objs_per_batch 的取值。