在 pytorch 中实现 x=T if abs(x)>T 作为激活函数

Implement x=T if abs(x)>T as an activation function in pytorch

我想在pytorch中实现以下激活函数:

x = T if abs(x)>T else x

我可以用 torch.clamp(min=-T, max=T) 做一些接近的事情,但这并不是我想要的行为(这与上面 x>-T 的行为相同,但 return -T 表示 x<-T)。有什么火炬功能可以帮助我实现这个目标吗?

torch.where 正是这样做的:

x = torch.where(torch.abs(x) > T, T, x)