如果低于某个阈值则剪裁张量
Clipping a tensor if below certain treshold
我有一个暗淡的 n x 3
张量,包含 n
3d 向量。我想用 dim n x 3
计算一个新的张量。如果向量的范数下降到某个阈值以下,我想将其设置为零向量并获得一个新的张量,其中包含已更改向量的索引位置。
示例:tensor([1, 2, 3], [0, 1, 1], [4, 5, 6], ...) 将导致 tensor([1, 2, 3], [0, 0, 0], [4, 5, 6], ...) with tensor([1]) 如果阈值设置为 1.5.
如何在不使用循环的情况下实现这一目标?感谢您的帮助。
只需执行 a[vector_norm(a,dim=1) < thr] = 0
,其中 thr
是您的阈值。这是一个演示。
import torch
from torch.linalg import vector_norm
n = 10
a = torch.rand(10,3)
print('a before:',a)
thr = 1
ind = vector_norm(a,dim=1) < thr
a[ind] = 0
print('a after:',a)
print('list of indices',ind.nonzero())
例子的结果运行:
a before: tensor([[0.0708, 0.7559, 0.3974],
[0.2969, 0.0974, 0.8652],
[0.8074, 0.8180, 0.2432],
[0.9006, 0.2447, 0.1602],
[0.6289, 0.1976, 0.8543],
[0.2109, 0.7539, 0.6334],
[0.9100, 0.2514, 0.2314],
[0.6657, 0.1940, 0.6565],
[0.4577, 0.8439, 0.5681],
[0.5566, 0.9979, 0.1468]])
a after: tensor([[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000],
[0.8074, 0.8180, 0.2432],
[0.0000, 0.0000, 0.0000],
[0.6289, 0.1976, 0.8543],
[0.2109, 0.7539, 0.6334],
[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000],
[0.4577, 0.8439, 0.5681],
[0.5566, 0.9979, 0.1468]])
list of indices tensor([[0],
[1],
[3],
[6],
[7]])
我有一个暗淡的 n x 3
张量,包含 n
3d 向量。我想用 dim n x 3
计算一个新的张量。如果向量的范数下降到某个阈值以下,我想将其设置为零向量并获得一个新的张量,其中包含已更改向量的索引位置。
示例:tensor([1, 2, 3], [0, 1, 1], [4, 5, 6], ...) 将导致 tensor([1, 2, 3], [0, 0, 0], [4, 5, 6], ...) with tensor([1]) 如果阈值设置为 1.5.
如何在不使用循环的情况下实现这一目标?感谢您的帮助。
只需执行 a[vector_norm(a,dim=1) < thr] = 0
,其中 thr
是您的阈值。这是一个演示。
import torch
from torch.linalg import vector_norm
n = 10
a = torch.rand(10,3)
print('a before:',a)
thr = 1
ind = vector_norm(a,dim=1) < thr
a[ind] = 0
print('a after:',a)
print('list of indices',ind.nonzero())
例子的结果运行:
a before: tensor([[0.0708, 0.7559, 0.3974],
[0.2969, 0.0974, 0.8652],
[0.8074, 0.8180, 0.2432],
[0.9006, 0.2447, 0.1602],
[0.6289, 0.1976, 0.8543],
[0.2109, 0.7539, 0.6334],
[0.9100, 0.2514, 0.2314],
[0.6657, 0.1940, 0.6565],
[0.4577, 0.8439, 0.5681],
[0.5566, 0.9979, 0.1468]])
a after: tensor([[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000],
[0.8074, 0.8180, 0.2432],
[0.0000, 0.0000, 0.0000],
[0.6289, 0.1976, 0.8543],
[0.2109, 0.7539, 0.6334],
[0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000],
[0.4577, 0.8439, 0.5681],
[0.5566, 0.9979, 0.1468]])
list of indices tensor([[0],
[1],
[3],
[6],
[7]])