trax tl.Relu 和 tl.ShiftRight 层嵌套在 Serial Combinator 中
trax tl.Relu and tl.ShiftRight layers are nested inside Serial Combinator
我正在尝试构建一个注意力模型,但 Relu 和 ShiftRight 层默认嵌套在 Serial Combinator 中。
这进一步给我带来了训练错误。
layer_block = tl.Serial(
tl.Relu(),
tl.LayerNorm(), )
x = np.array([[-2, -1, 0, 1, 2],
[-20, -10, 0, 10, 20]]).astype(np.float32)
layer_block.init(shapes.signature(x)) y = layer_block(x)
print(f'layer_block: {layer_block}')
输出
layer_block: Serial[
Serial[
Relu
]
LayerNorm
]
预期输出
layer_block: Serial[
Relu
LayerNorm
]
同样的问题出现在tl.ShiftRight()
以上代码摘自官方文档Example 5
提前致谢
我找不到上述问题的确切解决方案,但您可以使用 tl.Fn() 创建自定义函数并在其中添加 Relu and ShiftRight 函数代码。
def _zero_pad(x, pad, axis):
"""Helper for jnp.pad with 0s for single-axis case."""
pad_widths = [(0, 0)] * len(x.shape)
pad_widths[axis] = pad # Padding on axis.
return jnp.pad(x, pad_widths, mode='constant')
def f(x):
if mode == 'predict':
return x
padded = _zero_pad(x, (n_positions, 0), 1)
return padded[:, :-n_positions]
# set ShiftRight parameters as global
n_positions = 1
mode='train'
layer_block = tl.Serial(
tl.Fn('Relu', lambda x: jnp.where(x <= 0, jnp.zeros_like(x), x)),
tl.LayerNorm(),
tl.Fn(f'ShiftRight({n_positions})', f)
)
x = np.array([[-2, -1, 0, 1, 2],
[-20, -10, 0, 10, 20]]).astype(np.float32)
layer_block.init(shapes.signature(x))
y = layer_block(x)
print(f'layer_block: {layer_block}')
输出
layer_block: Serial[
Relu
LayerNorm
ShiftRight(1)
]
我正在尝试构建一个注意力模型,但 Relu 和 ShiftRight 层默认嵌套在 Serial Combinator 中。 这进一步给我带来了训练错误。
layer_block = tl.Serial(
tl.Relu(),
tl.LayerNorm(), )
x = np.array([[-2, -1, 0, 1, 2],
[-20, -10, 0, 10, 20]]).astype(np.float32)
layer_block.init(shapes.signature(x)) y = layer_block(x)
print(f'layer_block: {layer_block}')
输出
layer_block: Serial[
Serial[
Relu
]
LayerNorm
]
预期输出
layer_block: Serial[
Relu
LayerNorm
]
同样的问题出现在tl.ShiftRight()
以上代码摘自官方文档Example 5
提前致谢
我找不到上述问题的确切解决方案,但您可以使用 tl.Fn() 创建自定义函数并在其中添加 Relu and ShiftRight 函数代码。
def _zero_pad(x, pad, axis):
"""Helper for jnp.pad with 0s for single-axis case."""
pad_widths = [(0, 0)] * len(x.shape)
pad_widths[axis] = pad # Padding on axis.
return jnp.pad(x, pad_widths, mode='constant')
def f(x):
if mode == 'predict':
return x
padded = _zero_pad(x, (n_positions, 0), 1)
return padded[:, :-n_positions]
# set ShiftRight parameters as global
n_positions = 1
mode='train'
layer_block = tl.Serial(
tl.Fn('Relu', lambda x: jnp.where(x <= 0, jnp.zeros_like(x), x)),
tl.LayerNorm(),
tl.Fn(f'ShiftRight({n_positions})', f)
)
x = np.array([[-2, -1, 0, 1, 2],
[-20, -10, 0, 10, 20]]).astype(np.float32)
layer_block.init(shapes.signature(x))
y = layer_block(x)
print(f'layer_block: {layer_block}')
输出
layer_block: Serial[
Relu
LayerNorm
ShiftRight(1)
]