What is the difference between the function numpy.dot(), the @ operator, and the method .dot() for matrix-matrix multiplication?

Is there any difference? If not, which one is preferred by convention? Performance seems to be about the same:

import numpy as np

a = np.random.rand(1000, 1000)
b = np.random.rand(1000, 1000)
%timeit a.dot(b)      # 14.3 ms ± 374 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit np.dot(a, b)  # 14.7 ms ± 315 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit a @ b         # 15.1 ms ± 779 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

They all basically do the same thing, and the timings are comparable. According to the NumPy documentation for numpy.dot:

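  • If both a and b are 1-D arrays, it is the inner product of vectors (without complex conjugation).

  • If both a and b are 2-D arrays, it is matrix multiplication, but using matmul or a @ b is preferred.

  • If either a or b is 0-D (scalar), it is equivalent to multiply, and using numpy.multiply(a, b) or a * b is preferred.

  • If a is an N-D array and b is a 1-D array, it is a sum product over the last axis of a and b.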
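The four cases listed above can be checked directly. The following is a minimal sketch; the array names and values are made up for illustration, and np.vdot is included only to show what "without complex conjugation" means in the 1-D case.

```python
import numpy as np

# Case 1: 1-D with 1-D is an inner product WITHOUT complex conjugation.
u = np.array([1 + 2j, 3 + 0j])
v = np.array([1j, 2 + 0j])
inner = np.dot(u, v)         # sum(u * v)
conj_inner = np.vdot(u, v)   # sum(conj(u) * v), shown for contrast

# Case 2: 2-D with 2-D is matrix multiplication, same as @.
A = np.arange(6.0).reshape(2, 3)
B = np.arange(12.0).reshape(3, 4)
same_2d = np.array_equal(np.dot(A, B), A @ B)

# Case 3: a 0-D (scalar) operand makes dot behave like elementwise multiply.
same_scalar = np.array_equal(np.dot(A, 2.0), A * 2.0)

# Case 4: N-D with 1-D sums products over the last axis of both operands.
T = np.arange(24.0).reshape(2, 3, 4)
w = np.array([1.0, 0.0, 2.0, 1.0])
nd_1d = np.dot(T, w)                                   # shape (2, 3)
same_nd = np.array_equal(nd_1d, (T * w).sum(axis=-1))

print(inner, conj_inner, same_2d, same_scalar, nd_1d.shape, same_nd)
```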

They are almost identical, with a few exceptions.

a.dot(b) and np.dot(a, b) are exactly the same. See numpy.dot and ndarray.dot.
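A quick way to convince yourself of this is to compare the two results on the same inputs. This is only a sketch with small, arbitrary arrays; the one practical difference it shows is that the function form also accepts array-likes such as nested lists, while the method form needs an ndarray on the left.

```python
import numpy as np

rng = np.random.default_rng(0)   # seeded so the run is repeatable
a = rng.random((3, 4))
b = rng.random((4, 2))

via_method = a.dot(b)        # ndarray.dot
via_function = np.dot(a, b)  # numpy.dot
identical = np.allclose(via_method, via_function)

# The function form converts array-likes for you; lists have no .dot method.
from_lists = np.dot([[1, 2], [3, 4]], [[5, 6], [7, 8]])

print(identical)
print(from_lists)
```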

However, the documentation for numpy.dot says:

If both a and b are 2-D arrays, it is matrix multiplication, but using matmul or a @ b is preferred.

a @ b corresponds to numpy.matmul(a, b). dot and matmul differ as follows:

matmul differs from dot in two important ways:

  • Multiplication by scalars is not allowed, use * instead.
  • Stacks of matrices are broadcast together as if the matrices were elements, respecting the signature (n,k),(k,m)->(n,m):
>>> a = np.ones([9, 5, 7, 4])
>>> c = np.ones([9, 5, 4, 3])
>>> np.dot(a, c).shape
(9, 5, 7, 9, 5, 3)
>>> np.matmul(a, c).shape
(9, 5, 7, 3)
>>> # n is 7, k is 4, m is 3
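The first difference, scalars being rejected by matmul, is easy to miss because the quoted example only covers the stacking behaviour. Here is a small sketch of it. As far as I know the error raised is a ValueError, but the exact exception type and message are an assumption on my part and may vary between NumPy versions, so the code catches TypeError as well.

```python
import numpy as np

m = np.ones((2, 2))

# dot accepts a scalar operand and behaves like elementwise multiplication.
scaled = np.dot(m, 3)

# matmul (and therefore @) refuses a scalar operand; use * instead.
try:
    np.matmul(m, 3)
    matmul_rejected_scalar = False
except (ValueError, TypeError):
    # ValueError in the versions I have seen; TypeError caught defensively.
    matmul_rejected_scalar = True

print(scaled)
print(matmul_rejected_scalar)
```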

In my opinion, the best description and explanation is a clear example:

# How and when to use dot or matmul (@)?
# Suppose all B values of the dense networks are 0.
import numpy as np

inp = np.random.random((20, 10, 100, 4))  # 20 groups, 10 cases, 100 data points, 4 inputs
nnet1 = np.random.random((4, 3))  # 4 inputs, 3 outputs
nnet2 = np.random.random((3, 5))  # 3 inputs, 5 outputs
nnet3 = np.random.random((5, 2))  # 5 inputs, 2 outputs

# With 2-D weight matrices, @ and .dot() give the same result.
test1 = inp @ nnet1 @ nnet2 @ nnet3
test2 = inp.dot(nnet1).dot(nnet2).dot(nnet3)
print(test1.shape)
print(test2.shape)
print(test1[5, 3, 7, 1])  # 6th group, 4th case, 8th data point, 2nd network output
print(test2[5, 3, 7, 1])  # same element, computed with .dot()



inp = np.random.random((20, 10, 100, 4))  # 20 groups, 10 cases, 100 data points, 4 inputs
nnet1 = np.random.random((9, 4, 3))  # 9 different networks, each 4 inputs, 3 outputs
nnet2 = np.random.random((9, 3, 5))  # 9 different networks, each 3 inputs, 5 outputs
nnet3 = np.random.random((9, 5, 2))  # 9 different networks, each 5 inputs, 2 outputs

test1 = inp @ nnet1[0] @ nnet2[0] @ nnet3[0]  # network 0 only
test2 = inp.dot(nnet1 @ nnet2 @ nnet3)        # all 9 networks at once; adds a network axis
print(test1.shape)
print(test2.shape)
print(test1[5, 3, 7, 1])     # 6th group, 4th case, 8th data point, 2nd network output
print(test2[5, 3, 7, 0, 1])  # same element, taken from network 0

Output:

(20, 10, 100, 2)
(20, 10, 100, 2)
2.502277900709035
2.502277900709035
(20, 10, 100, 2)
(20, 10, 100, 9, 2)
0.6919054739155295
0.6919054739155295

Here you can compare the results element by element and see how they relate to each other, which should give you a deeper understanding...
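To make that comparison explicit for every network rather than just network 0, here is a rough sketch. The array nets is a hypothetical stand-in for the combined weights nnet1 @ nnet2 @ nnet3 from the second example, and the broadcasting trick at the end is just one way to get the same numbers out of @; it is not something the example above does.

```python
import numpy as np

rng = np.random.default_rng(0)
inp = rng.random((20, 10, 100, 4))
nets = rng.random((9, 4, 2))   # hypothetical stand-in for nnet1 @ nnet2 @ nnet3

# dot pairs the last axis of inp with the second-to-last axis of nets and
# keeps every other axis of both operands, so a network axis of length 9 appears.
dot_out = inp.dot(nets)        # shape (20, 10, 100, 9, 2)

# Slice k of that axis equals a plain matmul with network k.
per_network_match = all(
    np.allclose(dot_out[..., k, :], inp @ nets[k]) for k in range(nets.shape[0])
)

# matmul broadcasts the leading (stack) axes instead, so inp needs an extra
# axis of length 1 for the 9 networks to broadcast against.
matmul_out = inp[:, :, None, :, :] @ nets   # shape (20, 10, 9, 100, 2)
same_after_reorder = np.allclose(np.moveaxis(matmul_out, 2, 3), dot_out)

print(dot_out.shape, matmul_out.shape, per_network_match, same_after_reorder)
```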