如何提高 Python 脚本访问集群的计算时间

How to improve the computational time of a Python script having access to a cluster

我是 Python 的新手,所以对于您将在我的代码中看到的(大量)错误,我提前表示歉意。

我有一个非常简单的 Python 脚本,它给我带来了一些麻烦: 代码读取来自门模拟的输出文件,该文件的内容主要是 ASCII 格式的数值。文件中的数据按行组织,每行有 22 个不同的数字,总行数取决于模拟 运行ning 时间并且可能会很大(我目前最大的是 11国标)。这是文件的典型行作为参考:

      0       0  -1   0     0    -1    -1   4.99521423437116252053158e-01  7.910e-02  1.347e+01 -1.600e+01 -1.600e+01 -1.347e+01      22     1     0  -1   -1   -1 Compton NULL NULL

我的 Python 代码所做的是从该文件的每一行中读取数值并提取我感兴趣的数值;然后通过 for 循环分析一些提取的值,最后将一组这些值以 ASCII 格式写入输出文件。

代码运行良好,但是当我开始使用更长的模拟时间(因此输入文件更大)时,代码的计算时间变得太长,导致效率低下。 我的目标是尽可能减少计算时间。

这是我目前正在使用并导致速度变慢的部分代码:

#position of centres
PMTloc =np.array([[-270.00, -156.00], [-270.00, -104.00], [-270.00, -52.00], [-270.00,  0.00], [-270.00,    52.00], [-270.00,   104.00], [-270.00,  156.00], [-225.00,  -182.00], [-225.00, -130.00], [-225.00, -78.00], [-225.00,  -26.00], [-225.00,  26.00], [-225.00,   78.00], [-225.00,   130.00], [-225.00,  182.00], [-180.00,  -208.00], [-180.00, -156.00], [-180.00, -104.00], [-180.00, -52.00], [-180.00,  0.00], [-180.00,    52.00], [-180.00,   104.00], [-180.00,  156.00], [-180.00,  208.00], [-135.00,  -234.00], [-135.00, -182.00], [-135.00, -130.00], [-135.00, -78.00], [-135.00,  -26.00], [-135.00,  26.00], [-135.00,   78.00], [-135.00,   130.00], [-135.00,  182.00], [-135.00,  234.00], [-90.00,   -208.00], [-90.00,  -156.00], [-90.00,  -104.00], [-90.00,  -52.00], [-90.00,   0.00], [-90.00, 52.00], [-90.00,    104.00], [-90.00,   156.00], [-90.00,   208.00], [-45.00,   -234.00], [-45.00,  -182.00], [-45.00,  -130.00], [-45.00,  -78.00], [-45.00,   -26.00], [-45.00,   26.00], [-45.00,    78.00], [-45.00,    130.00], [-45.00,   182.00], [-45.00,   234.00], [0.00, -208.00], [0.00,    -156.00], [0.00,    -104.00], [0.00,    -52.00], [0.00, 0.00], [0.00,   52.00], [0.00,  104.00], [0.00, 156.00], [0.00, 208.00], [45.00,    -234.00], [45.00,   -182.00], [45.00,   -130.00], [45.00,   -78.00], [45.00,    -26.00], [45.00,    26.00], [45.00, 78.00], [45.00, 130.00], [45.00,    182.00], [45.00,    234.00], [90.00,    -208.00], [90.00,   -156.00], [90.00,   -104.00], [90.00,   -52.00], [90.00,    0.00], [90.00,  52.00], [90.00, 104.00], [90.00,    156.00], [90.00,    208.00], [135.00,   -234.00], [135.00,  -182.00], [135.00,  -130.00], [135.00,  -78.00], [135.00,   -26.00], [135.00,   26.00], [135.00,    78.00], [135.00,    130.00], [135.00,   182.00], [135.00,   234.00], [180.00,   -208.00], [180.00,  -156.00], [180.00,  -104.00], [180.00,  -52.00], [180.00,   0.00], [180.00, 52.00], [180.00,    104.00], [180.00,   156.00], [180.00,   208.00], [225.00,   -182.00], [225.00,  -130.00], [225.00,  -78.00], [225.00,   -26.00], [225.00,   26.00], [225.00,    78.00], [225.00,    130.00], [225.00,   182.00]])


#checking number of arguments passed
n = len(sys.argv)
if n<3:
    print("Error number of arguments passed is incorrect. Try again...")

print("argv[0]: {0}".format(argv[0]))
print("argv[1]: {0}".format(argv[1]))
input_file = argv[1] 
output_file = argv[2] 
hits = []
with open(input_file, 'r') as DataIn:
    for line in DataIn:
        hits += [line.split()]



totalPMT = 108
runID = [x[0] for x in hits] 
runid=[int(i) for i in runID]

eventID = [x[1] for x in hits]
eventid=[int(i) for i in eventID]


time = [x[7] for x in hits]
time_fl = [float(i) for i in time]

posX = [x[10] for x in hits]
posx = [float(i) for i in posX]

posY = [x[11] for x in hits]
posy = [float(i) for i in posY]

posZ = [x[12] for x in hits]
posz = [float(i) for i in posZ]

partID = [x[13] for x in hits]
partid = [int(i) for i in partID]


PMTs = [0]*108 #initial output signal
nscinti = 0 
nphoton = 0 
startscinti = 0 
zscinti = 0 
Etres = 100 #energy treshold
numb_22 = 0
numb_0 = 0
eventIDnow = [0]*len(eventid) 
eventidnow = [int(i) for i in eventIDnow]
runIDnow = [0]*len(runid)
runidnow = [int(i) for i in runIDnow]


for i in range(len(eventid)):

    if ((eventid[i] > eventidnow[i]) | (runid[i] > runidnow[i])):
    
        if ((nscinti>0) & (partid[i]==22) & (nphoton>Etres)):  #check if last recorded event should be written 
           
            with open(output_file, 'a+') as DataOut:
                DataOut.write("{0} {1} {2}\n".format(startscinti, zscinti, ' '.join(map(str, PMTs))))

     
            
        
        
    for l in range(len(eventidnow)): #updating values of eventidnow and runidnow
        eventidnow[l] =  eventid[i]
        runidnow[l] = runid[i]

    startscinti = time_fl[i]
    zscinti = posz[i]
    nphoton = 0
    nscinti += 1
    for j in range(len(PMTs)):
        PMTs[j] = 0

#checking where the event get recorded
for k in range(len(PMTs)):
    if ((k<=totalPMT) & ((posx[i]-PMTloc[k][0])*(posx[i]-PMTloc[k][0]) + (posy[i]-PMTloc[k][1])*(posy[i]-PMTloc[k][1])<=23*23)):

        PMTs[k] += 1        
        nphoton += 1
        break

我完全意识到这段代码远非完美和优化,所以我愿意接受任何关于如何改进它的建议。

根据我所做的测量,脚本需要大约 35 分钟才能完成对 224MB 输入文件的操作,对于大约 10GB 的输入文件,计算时间超过 100 小时。

我想补充一点,我可以访问一个集群,我可以在其中 运行 此代码,并且我最多可以使用 12-16 个内核 ,理想情况下我想利用这个核心来提高我的脚本的性能。通过做一些研究,我发现了多处理和并行化,但由于我缺乏经验,我不确定这些方法是否可以应用于我的案例,也不确定它们是否真的会减少计算时间。尽管如此,我还是无法正确实施它们。任何有关如何使用多处理或并行化的帮助将不胜感激。

无论如何,我愿意接受任何可能的改进我的代码的建议,如果有更好的方法来实现我的目标,它不必利用集群。

感谢大家抽出宝贵时间!

更新 19/08

感谢@DarrylG 的帮助,我能够改进脚本的计算时间,大约是我初始时间的 1/8,我对这个结果很满意。我会 post 在这里使用我现在使用的版本:

import sys
import numpy as np
import time
import numpy as np
import itertools
import sys
from sys import argv
import os
import time
from timeit import default_timer as timer
from datetime import timedelta

def int_or_float(s):
    ' Convert string from to int or float '
    try:
        return int(s)
    except ValueError:
        try:
            return float(s)
        except ValueError:
            return s

def process(hits):

    totalPMT = 108

    runid, eventid, time_fl, posx, posy, posz, partid = zip(*[(x[0], x[1], x[7], x[10], x[11], x[12], x[13]) for x in hits])
    
    runid=[int(i) for i in runid]
    eventid=[int(i) for i in eventid]
    time_fl = [float(i) for i in time_fl]
    posx = [float(i) for i in posx]
    posy = [float(i) for i in posy]
    posz = [float(i) for i in posz]
    partid = [int(i) for i in partid]
    
    

    PMTs = [0]*108 #initial output signal
    nscinti, nscinti, nphoton, startscinti, zscinti = [0]*5

    Etres = 100 #energy treshold
    numb_22, numb_0 = 0, 0
    
    eventIDnow = [0]*len(eventid) 
    runIDnow = [0]*len(runid)

    # Previous value of eventID and runID
    eventIDnow = 0
    runIDnow = 0
    
    results = []
    for i in range(len(eventid)):
        if (eventid[i] > eventIDnow) or (runid[i] > runIDnow):
            if ((nscinti>0) and (partid[i]==22) and (nphoton>Etres)):
                print("Writing output file")
            
                results.append("{0} {1} {2}\n".format(startscinti, zscinti, ' '.join(map(str, PMTs))))


            eventIDnow = eventid[i]
            runIDnow = runid[i]

            startscinti = time_fl[i]
            zscinti = posz[i]
            nphoton = 0
            nscinti += 1
            PMTs = [0]*len(PMTs)
       
        #checking where the event get recorded
        k = next((k for k in range(len(PMTs)) 
                    if ((k<=totalPMT) and 
                        ((posx[i]-PMTloc[k][0])*(posx[i]-PMTloc[k][0]) + 
                         (posy[i]-PMTloc[k][1])*(posy[i]-PMTloc[k][1])<=23*23))), 
                None)
        if k:
            PMTs[k] += 1        
            nphoton += 1
    
    return results
                
#position of centres
PMTloc =np.array([[-270.00, -156.00], [-270.00, -104.00], [-270.00, -52.00], [-270.00,  0.00], [-270.00,    52.00], [-270.00,   104.00], [-270.00,  156.00], [-225.00,  -182.00], [-225.00, -130.00], [-225.00, -78.00], [-225.00,  -26.00], [-225.00,  26.00], [-225.00,   78.00], [-225.00,   130.00], [-225.00,  182.00], [-180.00,  -208.00], [-180.00, -156.00], [-180.00, -104.00], [-180.00, -52.00], [-180.00,  0.00], [-180.00,    52.00], [-180.00,   104.00], [-180.00,  156.00], [-180.00,  208.00], [-135.00,  -234.00], [-135.00, -182.00], [-135.00, -130.00], [-135.00, -78.00], [-135.00,  -26.00], [-135.00,  26.00], [-135.00,   78.00], [-135.00,   130.00], [-135.00,  182.00], [-135.00,  234.00], [-90.00,   -208.00], [-90.00,  -156.00], [-90.00,  -104.00], [-90.00,  -52.00], [-90.00,   0.00], [-90.00, 52.00], [-90.00,    104.00], [-90.00,   156.00], [-90.00,   208.00], [-45.00,   -234.00], [-45.00,  -182.00], [-45.00,  -130.00], [-45.00,  -78.00], [-45.00,   -26.00], [-45.00,   26.00], [-45.00,    78.00], [-45.00,    130.00], [-45.00,   182.00], [-45.00,   234.00], [0.00, -208.00], [0.00,    -156.00], [0.00,    -104.00], [0.00,    -52.00], [0.00, 0.00], [0.00,   52.00], [0.00,  104.00], [0.00, 156.00], [0.00, 208.00], [45.00,    -234.00], [45.00,   -182.00], [45.00,   -130.00], [45.00,   -78.00], [45.00,    -26.00], [45.00,    26.00], [45.00, 78.00], [45.00, 130.00], [45.00,    182.00], [45.00,    234.00], [90.00,    -208.00], [90.00,   -156.00], [90.00,   -104.00], [90.00,   -52.00], [90.00,    0.00], [90.00,  52.00], [90.00, 104.00], [90.00,    156.00], [90.00,    208.00], [135.00,   -234.00], [135.00,  -182.00], [135.00,  -130.00], [135.00,  -78.00], [135.00,   -26.00], [135.00,   26.00], [135.00,    78.00], [135.00,    130.00], [135.00,   182.00], [135.00,   234.00], [180.00,   -208.00], [180.00,  -156.00], [180.00,  -104.00], [180.00,  -52.00], [180.00,   0.00], [180.00, 52.00], [180.00,    104.00], [180.00,   156.00], [180.00,   208.00], [225.00,   -182.00], [225.00,  -130.00], [225.00,  -78.00], [225.00,   -26.00], [225.00,   26.00], [225.00,    78.00], [225.00,    130.00], [225.00,   182.00]])

#checking number of arguments passed
if __name__ == "__main__":
    n = len(sys.argv)
    if n<3:
        print("Error number of arguments passed is incorrect. Try again...")

    print("argv[0]: {0}".format(argv[0]))
    print("argv[1]: {0}".format(argv[1]))
    input_file = argv[1] 
    output_file = argv[2] 
    
    #t0 = time.time()
    start = timer()
    hits = []
    with open(input_file, 'r') as data_in:
        for line in data_in:
            hits += [line.split()]
        
    results = process(hits)

    with open(output_file, 'a') as data_out:
        data_out.writelines(results)
        
    
    end = timer()
    print("Elapsed Time: ", timedelta(seconds = end-start))
        

我对这个脚本有疑问,但更准确地说,它与这部分有关:

hits = []
    with open(input_file, 'r') as data_in:
        for line in data_in:
            hits += [line.split()]

与我使用函数 int_or_float(s) :

相比,这种分割输入文件行的方式应该是非常低效的
 with open(input_file, 'r') as data_in:
        hits = [[int_or_float(s) for s in line.split()] for line in data_in]

然而,在使用多个输入文件(大小从几 MB 到几 GB 不等)进行测试后,我使用第一种方法(效率低下的方法)在计算时间方面得到了最好的结果。作为参考,分析大约 2.3GB 的相同输入文件,“低效”方法大约需要 50 分钟,而“功能”方法大约需要 55 分钟;在这两种情况下,我都使用相同的 Python 代码来分析输入文件,唯一的区别是我使用的方法,“低效”或“功能”。

知道为什么会这样吗?

感谢您的帮助!

更新 20/08

感谢@DarrylG,我得到了新版本的脚本,今天早上我有机会对其进行测试。我比较了我的最新版本,我 post 在我的原始 post 的 19/08 更新中编辑的版本,以及可以在接受的答案中找到的 DarrylG 的最新版本(版本是记为“更快”)。我比较了两个版本的脚本需要生成一个完整的输出文件的时间,测试是使用不同大小的输入文件完成的:

输入文件大小 224MB

输入文件大小 2.3GB

我还检查了生成的输出文件以查看是否存在任何差异,但它们是相同的。

尽管这远非深度测试,但我相信可以肯定地说 DarrylG 最新版本是最快的,尤其是在处理大尺寸输入文件时。

感谢大家的帮助,如果有人对脚本的性能有更多疑问,请随时与我联系。

更新的答案Post

int_or_float 版本花费时间较长的原因是您进行了两次数据转换。

  • 一次使用用户函数int_or_float
  • 然后在各个列上再次使用内置函数 int, float

实际上

hits += [line.split()]

并且:

hits.append([line.split()])

两者 运行 都以相同的速度进行,因为两者都是在适当的位置完成的(即我下面的原始评论对此不正确)。 参见 Why does += behave unexpectedly on lists?

比上一个 OP 版本更快

这比 OP 修订版略有加快(即 19 秒对我机器上的 22 秒)。

import sys
import numpy as np
import time
import numpy as np
import itertools
import sys
from sys import argv
import os
import time
from timeit import default_timer as timer
from datetime import timedelta

def get_data(input_file):
    '''
    Uses Numpy loadtxt for loading and data type conversion

    '''
    # Only load the columns of interest
    arr = np.loadtxt(input_file, usecols=[0, 1, 7, 10, 11, 12, 13], converters = {0:int, 1:int, 2:float, 3:float, 
                                                                                   4:float, 5:float, 6:int})
    
    # Place into individual arrays
    runid, eventid, time_fl, posx, posy, posz, partid = [arr[:, i] for i in range(arr.shape[1])]

    runid = runid.astype(int)
    eventid = eventid.astype(int)
    partid = partid.astype(int)

    # Convert from numpy to regular lists
    runid = runid.tolist()
    eventid = eventid.tolist()
    time_fl = time_fl.tolist()
    posx = posx.tolist()
    posy = posy.tolist()
    posz = posz.tolist()
    partid = partid.tolist()
    
    return runid, eventid, time_fl, posx, posy, posz, partid


def process(runid, eventid, time_fl, posx, posy, posz, partid):

    totalPMT = 108

    PMTs = [0]*108 #initial output signal
    nscinti, nscinti, nphoton, startscinti, zscinti = [0]*5

    Etres = 100 #energy treshold
    numb_22, numb_0 = 0, 0
    
    eventIDnow = [0]*len(eventid) 
    runIDnow = [0]*len(runid)

    # Previous value of eventID and runID
    eventIDnow = 0
    runIDnow = 0
    
    results = []
    for i in range(len(eventid)):
        if (eventid[i] > eventIDnow) or (runid[i] > runIDnow):
            if ((nscinti>0) and (partid[i]==22) and (nphoton>Etres)):
                print("Writing output file")
            
                results.append("{0} {1} {2}\n".format(startscinti, zscinti, ' '.join(map(str, PMTs))))


            eventIDnow = eventid[i]
            runIDnow = runid[i]

            startscinti = time_fl[i]
            zscinti = posz[i]
            nphoton = 0
            nscinti += 1
            PMTs = [0]*len(PMTs)
       
        #checking where the event get recorded
        k = next((k for k in range(len(PMTs)) 
                    if ((k<=totalPMT) and 
                        ((posx[i]-PMTloc[k][0])*(posx[i]-PMTloc[k][0]) + 
                         (posy[i]-PMTloc[k][1])*(posy[i]-PMTloc[k][1])<=23*23))), 
                None)
        if k:
            PMTs[k] += 1        
            nphoton += 1
    
    return results
                
#position of centres
PMTloc =np.array([[-270.00, -156.00], [-270.00, -104.00], [-270.00, -52.00], [-270.00,  0.00], [-270.00,    52.00], [-270.00,   104.00], [-270.00,  156.00], [-225.00,  -182.00], [-225.00, -130.00], [-225.00, -78.00], [-225.00,  -26.00], [-225.00,  26.00], [-225.00,   78.00], [-225.00,   130.00], [-225.00,  182.00], [-180.00,  -208.00], [-180.00, -156.00], [-180.00, -104.00], [-180.00, -52.00], [-180.00,  0.00], [-180.00,    52.00], [-180.00,   104.00], [-180.00,  156.00], [-180.00,  208.00], [-135.00,  -234.00], [-135.00, -182.00], [-135.00, -130.00], [-135.00, -78.00], [-135.00,  -26.00], [-135.00,  26.00], [-135.00,   78.00], [-135.00,   130.00], [-135.00,  182.00], [-135.00,  234.00], [-90.00,   -208.00], [-90.00,  -156.00], [-90.00,  -104.00], [-90.00,  -52.00], [-90.00,   0.00], [-90.00, 52.00], [-90.00,    104.00], [-90.00,   156.00], [-90.00,   208.00], [-45.00,   -234.00], [-45.00,  -182.00], [-45.00,  -130.00], [-45.00,  -78.00], [-45.00,   -26.00], [-45.00,   26.00], [-45.00,    78.00], [-45.00,    130.00], [-45.00,   182.00], [-45.00,   234.00], [0.00, -208.00], [0.00,    -156.00], [0.00,    -104.00], [0.00,    -52.00], [0.00, 0.00], [0.00,   52.00], [0.00,  104.00], [0.00, 156.00], [0.00, 208.00], [45.00,    -234.00], [45.00,   -182.00], [45.00,   -130.00], [45.00,   -78.00], [45.00,    -26.00], [45.00,    26.00], [45.00, 78.00], [45.00, 130.00], [45.00,    182.00], [45.00,    234.00], [90.00,    -208.00], [90.00,   -156.00], [90.00,   -104.00], [90.00,   -52.00], [90.00,    0.00], [90.00,  52.00], [90.00, 104.00], [90.00,    156.00], [90.00,    208.00], [135.00,   -234.00], [135.00,  -182.00], [135.00,  -130.00], [135.00,  -78.00], [135.00,   -26.00], [135.00,   26.00], [135.00,    78.00], [135.00,    130.00], [135.00,   182.00], [135.00,   234.00], [180.00,   -208.00], [180.00,  -156.00], [180.00,  -104.00], [180.00,  -52.00], [180.00,   0.00], [180.00, 52.00], [180.00,    104.00], [180.00,   156.00], [180.00,   208.00], [225.00,   -182.00], [225.00,  -130.00], [225.00,  -78.00], [225.00,   -26.00], [225.00,   26.00], [225.00,    78.00], [225.00,    130.00], [225.00,   182.00]])

#checking number of arguments passed
if __name__ == "__main__":
    n = len(sys.argv)
    if n<3:
        print("Error number of arguments passed is incorrect. Try again...")

    print("argv[0]: {0}".format(argv[0]))
    print("argv[1]: {0}".format(argv[1]))
    input_file = argv[1] 
    output_file = argv[2] 

    #t0 = time.time()
    start = timer()
    
    runid, eventid, time_fl, posx, posy, posz, partid = get_data(input_file)
    results = process(runid, eventid, time_fl, posx, posy, posz, partid)

    with open(output_file, 'a') as data_out:
        data_out.writelines(results)
        
    
    end = timer()
    print("Elapsed Time: ", timedelta(seconds = end-start))