如何从 Pytorch 张量中删除每一列填充零的列?
How to get rid of every column that are filled with zero from a Pytorch tensor?
我有一个 pytorch 张量 A
如下所示:
A =
tensor([[ 4, 3, 3, ..., 0, 0, 0],
[ 13, 4, 13, ..., 0, 0, 0],
[707, 707, 4, ..., 0, 0, 0],
...,
[ 7, 7, 7, ..., 0, 0, 0],
[ 0, 0, 0, ..., 0, 0, 0],
[195, 195, 195, ..., 0, 0, 0]], dtype=torch.int32)
我愿意:
- 找出所有条目都等于 0 的所有列
- 只删除所有条目都等于 0 的列
我可以想象做:
zero_list = []
for j in range(A.size()[1]):
if torch.sum(A[:,j]) == 0:
zero_list = zero_list.append(j)
识别其元素只有 0 的列
但我不确定如何从原始张量中删除这些填充为 0 的列。
如何根据索引号从pytorch张量中删除零列?
谢谢,
索引要保留的列比索引要删除的列更有意义。
valid_cols = []
for col_idx in range(A.size(1)):
if not torch.all(A[:, col_idx] == 0):
valid_cols.append(col_idx)
A = A[:, valid_cols]
或者更神秘一点
valid_cols = [col_idx for col_idx, col in enumerate(torch.split(A, 1, dim=1)) if not torch.all(col == 0)]
A = A[:, valid_cols]
Identify all the columns whose all of its entries are equal to 0
non_empty_mask = A.abs().sum(dim=0).bool()
这对每列的绝对值求和,然后将结果转换为布尔值,即如果总和为零,则为 False
,否则为 True
。
Delete only those columns that has all of their entries equal to 0
A[:,non_empty_mask]
这只是将掩码应用于原始张量,即它保留 non_empty_mask
为 True
的行。
我有一个 pytorch 张量 A
如下所示:
A =
tensor([[ 4, 3, 3, ..., 0, 0, 0],
[ 13, 4, 13, ..., 0, 0, 0],
[707, 707, 4, ..., 0, 0, 0],
...,
[ 7, 7, 7, ..., 0, 0, 0],
[ 0, 0, 0, ..., 0, 0, 0],
[195, 195, 195, ..., 0, 0, 0]], dtype=torch.int32)
我愿意:
- 找出所有条目都等于 0 的所有列
- 只删除所有条目都等于 0 的列
我可以想象做:
zero_list = []
for j in range(A.size()[1]):
if torch.sum(A[:,j]) == 0:
zero_list = zero_list.append(j)
识别其元素只有 0 的列 但我不确定如何从原始张量中删除这些填充为 0 的列。
如何根据索引号从pytorch张量中删除零列?
谢谢,
索引要保留的列比索引要删除的列更有意义。
valid_cols = []
for col_idx in range(A.size(1)):
if not torch.all(A[:, col_idx] == 0):
valid_cols.append(col_idx)
A = A[:, valid_cols]
或者更神秘一点
valid_cols = [col_idx for col_idx, col in enumerate(torch.split(A, 1, dim=1)) if not torch.all(col == 0)]
A = A[:, valid_cols]
Identify all the columns whose all of its entries are equal to 0
non_empty_mask = A.abs().sum(dim=0).bool()
这对每列的绝对值求和,然后将结果转换为布尔值,即如果总和为零,则为 False
,否则为 True
。
Delete only those columns that has all of their entries equal to 0
A[:,non_empty_mask]
这只是将掩码应用于原始张量,即它保留 non_empty_mask
为 True
的行。