torch.squeeze 和 torch.unsqueeze 相当于手电筒 (arrayfire)

torch.squeeze and torch.unsqueeze equivalent in Flashlight (arrayfire)

我正在将 PyTorch 代码移植到 Flashlight 代码。 Pytorch 中 squeezeunsqueeze 的等效 Arrayfire 或 Flashlight 函数是什么?

processed_query = self.query_layer(query.unsqueeze(1))

energies = energies.squeeze(-1)

如何将其转换为 Arrayfire 代码? (或者,手电筒?)

您可以使用 af::moddims 函数执行此操作:

array a = randu(10, 1, 10, 10);
squeezed_a = moddims(a, 10, 10, 10);