swapaxes 以及它是如何实现的?
swapaxes and how it is implemented?
我想知道是否有人可以向我解释这段代码?
c = self.config
assert len(pair_act.shape) == 3
assert len(pair_mask.shape) == 2
assert c.orientation in ['per_row', 'per_column']
if c.orientation == 'per_column':
pair_act = jnp.swapaxes(pair_act, -2, -3)
pair_mask = jnp.swapaxes(pair_mask, -1, -2)
好像pair_act是3维数组,pair_mask是二维数组?数字-1、-2 和-3 是什么?对于3维数组,我最初的想法是数组为0,列为1,行为2。那么-号从何而来呢?任何数组示例将不胜感激。感谢您的帮助。
jax.numpy.swapaxes
的文档在这里:https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.swapaxes.html
swapaxes
的作用本质上是调换提供的两个轴,从而产生不同形状的数组:
import jax.numpy as jnp
x = jnp.arange(24).reshape((2, 3, 4))
print(x.shape)
# (2, 3, 4)
y = jnp.swapaxes(x, 1, 2)
print(y.shape)
# (2, 4, 3)
作为 numpy 索引的标准,负数从末尾倒数;这里的索引指的是形状中的条目(长度为 3),因此 -2, -1
等同于 1, 2
:
y = jnp.swapaxes(x, -2, -1)
print(y.shape)
# (2, 4, 3)
swapaxes
的结果等同于适当构造的 transpose
操作:
y2 = jnp.transpose(x, (0, 2, 1))
print((y == y2).all())
# True
我想知道是否有人可以向我解释这段代码?
c = self.config
assert len(pair_act.shape) == 3
assert len(pair_mask.shape) == 2
assert c.orientation in ['per_row', 'per_column']
if c.orientation == 'per_column':
pair_act = jnp.swapaxes(pair_act, -2, -3)
pair_mask = jnp.swapaxes(pair_mask, -1, -2)
好像pair_act是3维数组,pair_mask是二维数组?数字-1、-2 和-3 是什么?对于3维数组,我最初的想法是数组为0,列为1,行为2。那么-号从何而来呢?任何数组示例将不胜感激。感谢您的帮助。
jax.numpy.swapaxes
的文档在这里:https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.swapaxes.html
swapaxes
的作用本质上是调换提供的两个轴,从而产生不同形状的数组:
import jax.numpy as jnp
x = jnp.arange(24).reshape((2, 3, 4))
print(x.shape)
# (2, 3, 4)
y = jnp.swapaxes(x, 1, 2)
print(y.shape)
# (2, 4, 3)
作为 numpy 索引的标准,负数从末尾倒数;这里的索引指的是形状中的条目(长度为 3),因此 -2, -1
等同于 1, 2
:
y = jnp.swapaxes(x, -2, -1)
print(y.shape)
# (2, 4, 3)
swapaxes
的结果等同于适当构造的 transpose
操作:
y2 = jnp.transpose(x, (0, 2, 1))
print((y == y2).all())
# True