如何在 python 中保持其他索引通用的同时实现数组中的 select 特定元素?

How to realize select specific elements in array while keeping other indices general in python?

对于 rank-4 数组,M[n0,n1,n2,n3]

我愿意 M[m0,m1,:,:]*(some operation) 所有 m0,m1

M[:,m1,:,m3]*(some operation) 对所有 m1,m3 进行一些操作

我能做到

import numpy as np
import time
n0=2
n1=2
n2=2
n3=2
M = np.zeros((n0,n1,n2,n3))
M2 = np.zeros((n0,n1,n2,n3))
M3 = np.zeros((n0,n1,n2,n3))

i = 0
for m0 in range(M.shape[0]):
    for m1 in range(M.shape[1]):  
        for m2 in range(M.shape[2]):  
            for m3 in range(M.shape[3]):  
                M[m0,m1,m2,m3] = i
                i = i + 1

input = 0


if input == 0:
    for m0 in range(M.shape[0]):
        for m1 in range(M.shape[1]):   
            M2[m0,m1,:,:] =  M[m0,m1,:,:]*(m0+m1)

            
elif input == 1:
    for m1 in range(M.shape[1]):
        for m3 in range(M.shape[3]):  
            M2[:,m1,:,m3] = M[:,m1,:,m3]*(m1-m3)

for m0 in range(M2.shape[0]):
    for m1 in range(M2.shape[1]):  
        for m2 in range(M2.shape[2]):  
            for m3 in range(M2.shape[3]):  
               # M2[m0,m1,m2,m3] = i
                print(m0, m1, m2, m3, 'M', M[m0,m1,m2,m3], 'M2', M2[m0,m1,m2,m3])

(使用 : 跳过几个循环似乎比显式循环遍历所有索引 m0 - m3 更快。这就是这个问题的动机:利用 :m0+m1m1-m3 以某种方式随机拾取)

样本输出是

0 0 0 0 M 0.0 M2 0.0
0 0 0 1 M 1.0 M2 0.0
0 0 1 0 M 2.0 M2 0.0
0 0 1 1 M 3.0 M2 0.0
0 1 0 0 M 4.0 M2 4.0
0 1 0 1 M 5.0 M2 5.0
0 1 1 0 M 6.0 M2 6.0
0 1 1 1 M 7.0 M2 7.0
1 0 0 0 M 8.0 M2 8.0
1 0 0 1 M 9.0 M2 9.0
1 0 1 0 M 10.0 M2 10.0
1 0 1 1 M 11.0 M2 11.0
1 1 0 0 M 12.0 M2 24.0
1 1 0 1 M 13.0 M2 26.0
1 1 1 0 M 14.0 M2 28.0
1 1 1 1 M 15.0 M2 30.0

我的问题是,有没有简单的方法可以通过输入目标数组元素的位置,例如0,1,M[m0,m1,:.:]来实现上面的代码(input = 0) 的代码; 1,3 M[:,m1,:,m3]对于上面代码的下半部分(input = 1)?并包括其他情况,例如 0,2; 0,3;1,2;2,3。本质上,改变索引标签的位置和 :.

我可以让 python 打印包含所有情况的代码,但我希望有更简单的东西

我不知道你会需要什么复杂的操作,但目前我认为你对numpy计算还不熟悉。为此,我重写了整个答案,以更简单快捷的方式完成你的三个要求:

# Yours
n0 = n1 = n2 = n3 = 2
M = np.zeros((n0, n1, n2, n3))
M2 = np.zeros((n0, n1, n2, n3))
M3 = np.zeros((n0, n1, n2, n3))

i = 0
for m0 in range(M.shape[0]):
    for m1 in range(M.shape[1]):  
        for m2 in range(M.shape[2]):  
            for m3 in range(M.shape[3]):  
                M[m0,m1,m2,m3] = i
                i = i + 1

for m0 in range(M.shape[0]):
    for m1 in range(M.shape[1]):   
        M2[m0,m1,:,:] =  M[m0,m1,:,:]*(m0+m1)


for m1 in range(M.shape[1]):
    for m3 in range(M.shape[3]):  
        M3[:,m1,:,m3] = M[:,m1,:,m3]*(m1-m3)


# Mine
N = np.arange(n0 * n1 * n2 * n3, dtype=float).reshape(n0, n1, n2, n3)
m0, m1, m2, m3 = np.indices(N.shape, sparse=True)
N2 = N * (m0 + m1)
N3 = N * (m1 - m3)

可以使用等号和方法ndarray.all来判断两个数组是否完全相等:

print((M == N).all())
print((M2 == N2).all())
print((M3 == N3).all())

大家可以自己试试,都是True。 如果你想知道原理,我可以试着写一下,但我现在没有太多时间。

还要提醒大家,在操作numpy数组的时候,尽量避免使用循环