numpy apply_along_axis 多维数据计算
numpy apply_along_axis computation on multidimensional data
我正在将一段J语言代码翻译成Python,但是python的应用函数的方式对我来说似乎有点不清楚...
我目前有一个 (3, 3, 2) 矩阵 A 和一个 (3, 3) 矩阵 B。
我想将 A 中的每个矩阵除以 B 中的行:
A = np.arange(1,19).reshape(3,3,2)
array([[[ 1, 2],
[ 3, 4],
[ 5, 6]],
[[ 7, 8],
[ 9, 10],
[11, 12]],
[[13, 14],
[15, 16],
[17, 18]]])
B = np.arange(1,10).reshape(3,3)
array([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
就是这样的结果
1 2
1.5 2
1.66667 2
1.75 2
1.8 2
1.83333 2
1.85714 2
1.875 2
1.88889 2
对于结果的第一个矩阵,我想计算的方式如下:
1/1 2/1
3/2 4/2
5/3 6/3
我试过了
np.apply_along_axis(np.divide,1,A,B)
但是它说
operands could not be broadcast together with shapes (10,) (10,10,2)
有什么建议吗?
提前谢谢你=]
ps。 J 代码是
A %"2 1 B
这意味着 "divide each matrix("2) 从 A 到 B 的每一行 ("1)"
或者只是
A % B
如果尾随维度匹配或为一,则广播有效!所以我们基本上可以添加一个虚拟维度!
import numpy as np
A = np.arange(1,19).reshape(3,3,2)
B = np.arange(1,10).reshape(3,3)
B = B[...,np.newaxis] # This adds new dummy dimension in the end, B's new shape is (3,3,1)
A/B
array([[[1. , 2. ],
[1.5 , 2. ],
[1.66666667, 2. ]],
[[1.75 , 2. ],
[1.8 , 2. ],
[1.83333333, 2. ]],
[[1.85714286, 2. ],
[1.875 , 2. ],
[1.88888889, 2. ]]])
我正在将一段J语言代码翻译成Python,但是python的应用函数的方式对我来说似乎有点不清楚...
我目前有一个 (3, 3, 2) 矩阵 A 和一个 (3, 3) 矩阵 B。
我想将 A 中的每个矩阵除以 B 中的行:
A = np.arange(1,19).reshape(3,3,2)
array([[[ 1, 2],
[ 3, 4],
[ 5, 6]],
[[ 7, 8],
[ 9, 10],
[11, 12]],
[[13, 14],
[15, 16],
[17, 18]]])
B = np.arange(1,10).reshape(3,3)
array([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
就是这样的结果
1 2
1.5 2
1.66667 2
1.75 2
1.8 2
1.83333 2
1.85714 2
1.875 2
1.88889 2
对于结果的第一个矩阵,我想计算的方式如下:
1/1 2/1
3/2 4/2
5/3 6/3
我试过了
np.apply_along_axis(np.divide,1,A,B)
但是它说
operands could not be broadcast together with shapes (10,) (10,10,2)
有什么建议吗? 提前谢谢你=]
ps。 J 代码是
A %"2 1 B
这意味着 "divide each matrix("2) 从 A 到 B 的每一行 ("1)"
或者只是
A % B
如果尾随维度匹配或为一,则广播有效!所以我们基本上可以添加一个虚拟维度!
import numpy as np
A = np.arange(1,19).reshape(3,3,2)
B = np.arange(1,10).reshape(3,3)
B = B[...,np.newaxis] # This adds new dummy dimension in the end, B's new shape is (3,3,1)
A/B
array([[[1. , 2. ],
[1.5 , 2. ],
[1.66666667, 2. ]],
[[1.75 , 2. ],
[1.8 , 2. ],
[1.83333333, 2. ]],
[[1.85714286, 2. ],
[1.875 , 2. ],
[1.88888889, 2. ]]])