swapaxes 以及它是如何实现的?

swapaxes and how it is implemented?

我想知道是否有人可以向我解释这段代码?

c = self.config

assert len(pair_act.shape) == 3
assert len(pair_mask.shape) == 2
assert c.orientation in ['per_row', 'per_column']

if c.orientation == 'per_column':
  pair_act = jnp.swapaxes(pair_act, -2, -3)
  pair_mask = jnp.swapaxes(pair_mask, -1, -2)

好像pair_act是3维数组,pair_mask是二维数组?数字-1、-2 和-3 是什么?对于3维数组,我最初的想法是数组为0,列为1,行为2。那么-号从何而来呢?任何数组示例将不胜感激。感谢您的帮助。

jax.numpy.swapaxes 的文档在这里:https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.swapaxes.html

swapaxes的作用本质上是调换提供的两个轴,从而产生不同形状的数组:

import jax.numpy as jnp

x = jnp.arange(24).reshape((2, 3, 4))
print(x.shape)
# (2, 3, 4)

y = jnp.swapaxes(x, 1, 2)
print(y.shape)
# (2, 4, 3)

作为 numpy 索引的标准,负数从末尾倒数;这里的索引指的是形状中的条目(长度为 3),因此 -2, -1 等同于 1, 2:

y = jnp.swapaxes(x, -2, -1)
print(y.shape)
# (2, 4, 3)

swapaxes 的结果等同于适当构造的 transpose 操作:

y2 = jnp.transpose(x, (0, 2, 1))
print((y == y2).all())
# True