How to pad zeros on a batch in PyTorch
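Is there a better way to do this? How can I pad a tensor with zeros without creating a new tensor object? I always need to feed inputs of the same batch size, so I want to zero-pad any input whose batch is smaller than batchsize. It is like zero-padding in NLP when a sequence is too short, except that here the padding is along the batch dimension.
Currently I create a new tensor, but because of that my GPU runs out of memory. I don't want to halve the batch size just to work around this operation.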
```python
import torch
from torch import nn


class MyModel(nn.Module):
    def __init__(self, batchsize=16):
        super().__init__()
        self.batchsize = batchsize

    def forward(self, x):
        b, d = x.shape
        print(x.shape)  # torch.Size([7, 32])
        if b != self.batchsize:  # 2. I need batches of size 16; if a batch isn't 16, I want to pad the rest with zeros
            # 3. so I create a new tensor, but this is bad because it greatly increases the GPU memory required
            # (dtype/device match x, so a CUDA input is not silently moved to the CPU)
            new_x = torch.zeros(self.batchsize, d, dtype=x.dtype, device=x.device)
            new_x[0:b, :] = x
            x = new_x
            b = self.batchsize
        print(x.shape)  # torch.Size([16, 32])
        return x


model = MyModel()
x = torch.randn((7, 32))  # 1. the batch is 7 because this is the last batch, and I don't want to use drop_last
y = model(x)
print(y.shape)
```
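The allocate-and-copy approach in the question can be sketched as a standalone helper that works for any number of trailing dimensions. The name pad_batch is hypothetical (not part of PyTorch); it assumes the input batch is never larger than the target size. It uses Tensor.new_zeros so the padded tensor inherits the input's dtype and device.

```python
import torch


def pad_batch(x: torch.Tensor, batch_size: int) -> torch.Tensor:
    """Hypothetical helper: zero-pad dim 0 of `x` up to `batch_size`.

    Works for any number of trailing dimensions. When padding is needed,
    the result is a NEW tensor with the same dtype and device as `x`.
    """
    b = x.shape[0]
    if b == batch_size:
        return x  # already full size, nothing to allocate
    if b > batch_size:
        raise ValueError(f"batch of {b} is larger than batch_size={batch_size}")
    # new_zeros inherits dtype and device from x, so a CUDA input stays on the GPU
    out = x.new_zeros((batch_size, *x.shape[1:]))
    out[:b] = x  # copy the real samples into the first b rows
    return out


x = torch.randn(7, 32)
padded = pad_batch(x, 16)
print(padded.shape)  # torch.Size([16, 32])
```

For scale: the padded float32 batch above holds 16 × 32 values, i.e. 2048 bytes, so the extra allocation is proportional to batchsize × d and only happens for the final, short batch.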
You can pad the extra elements like this:
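```python
import torch.nn.functional as F

# inside forward(), where b = x.shape[0]
n = self.batchsize - b
# Use ONE of the following, depending on the tensor rank and where the zeros should go:
new_x = F.pad(x, (0, 0, n, 0))  # 2D tensor: pad the start of the batch dimension
new_x = F.pad(x, (0, 0, 0, n))  # 2D tensor: pad the end of the batch dimension
new_x = F.pad(x, (0, 0, 0, 0, 0, n))  # 3D tensor: pad the end of the batch dimension
```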
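The pad tuples above follow F.pad's convention: pairs of (before, after) amounts, starting from the last dimension and moving backwards, so the batch dimension (dim 0) is always the final pair. A minimal sketch that builds that tuple for a tensor of any rank; pad_batch_dim and its at_start flag are hypothetical names used only for illustration.

```python
import torch
import torch.nn.functional as F


def pad_batch_dim(x: torch.Tensor, batch_size: int, at_start: bool = False) -> torch.Tensor:
    """Hypothetical helper: zero-pad dim 0 of `x` to `batch_size` via F.pad.

    F.pad reads the pad tuple from the LAST dimension backwards, two numbers
    (before, after) per dimension, so dim 0 is the final pair.
    """
    n = batch_size - x.shape[0]
    if n < 0:
        raise ValueError("input batch is larger than batch_size")
    dim0 = (n, 0) if at_start else (0, n)
    pad = (0, 0) * (x.dim() - 1) + dim0  # no padding on the non-batch dimensions
    return F.pad(x, pad)  # default mode="constant" fills with 0


x2 = torch.randn(7, 32)
x3 = torch.randn(7, 10, 32)
print(pad_batch_dim(x2, 16).shape)  # torch.Size([16, 32])
print(pad_batch_dim(x3, 16).shape)  # torch.Size([16, 10, 32])
```

For a 2D input this produces exactly (0, 0, 0, n), and for a 3D input (0, 0, 0, 0, 0, n), matching the calls above. Note that F.pad also returns a new tensor, so it shortens the code rather than avoiding the allocation itself.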