Python CS50x 的 DNA 问题中 运行 代码花费的时间太长

Python code takes too long to run in DNA problem from CS50x

我写了一个代码来解决 CS50 第 6 周的 DNA 问题。但是,当我在 运行 上 large.csv 数据库和序列时,至少需要一分钟才能生成输出。在 small.csv 上,它会立即产生输出。因此,我无法通过 check50。我想问题出在用于生成 STR 的最大重复序列数的函数阶段,但我不知道如何更有效地编写它。问题的完整描述在这里:https://cs50.harvard.edu/x/2021/psets/6/dna/#:~:text=check50%20cs50/problems/2021/x/dna

下面是数据库和序列的源文件: https://cdn.cs50.net/2019/fall/psets/6/dna/

这是我的代码:

import csv
import sys


def main():
    
    # check a proper input
    if len(sys.argv) != 3:
        sys.exit("Usage: python dna.py data.csv sequence.txt")
    
    # create a list for all data
    data_all = []
    
    # create a list for all STRs
    STR_all = []
    
    # write data to list
    with(open(sys.argv[1])) as data:
        reader = csv.DictReader(data)
        for row in reader:
            row["name"]
            data_all.append(row)
            
    # write header to a list 
    with(open(sys.argv[1])) as data:      
        reader = csv.reader(data)
        headings = next(reader)
        STR_all.append(headings)
    
    # delete "name" from header, it is on the first position    
    STR_all = STR_all[0]
    STR_all.pop(0)
            
    # create a string with DNA sequence
    with(open(sys.argv[2])) as seq:
        line = seq.read()
    
    # create a list with max number of repeating STR from a line(DNA)
    max_seq = []
    
    # enter data with string of STR and it's max repeating time    
    for i in range(len(STR_all)):
        result = f"{compare(STR_all[i], line)}"
        max_seq.append(result)
        
    # create a dictionary with a list of all STRs and according number of repeating sequences
    STR_with_max_seq = dict(zip(STR_all, max_seq))
    
    # compare values from data_all and STR_with_max_seq
    for i in range(len(data_all)):
        # delete name key and store key in variable "name"
        name = data_all[i].pop('name')
        if data_all[i] == STR_with_max_seq:
            print(name)
            sys.exit()
            break
        else:
            continue
        
    # Print if no match found
    print("No match")
        
    # variables that I used to check on different stages of writing a program
            
    # print(data_all)
    # print(line)
    # print(STR_all)
    # print(max_seq)
    # print(STR_with_max_seq)
    
    # print(len(data_all))
    # print(name)

    
def compare(STR, DNA):

    for key in DNA:
        l = len(STR)
        tmp_max = 0
        tmp = 0
        
        # iteration through the whole length of DNA
        for i in range(len(DNA)):
            if tmp > 0:
                tmp = 0
            
            # enters if sequences are equal
            if DNA[i: i + l] == STR:
                tmp += 1
                # increments tmp if its sequence repeats
                while DNA[i - l: i] == DNA[i: i + l]:
                    tmp += 1
                    i += l
                # update the max found number of repeating sequences    
                if tmp > tmp_max:
                    tmp_max = tmp
    
    return tmp_max

    
main()

更新:我已使用 time.monotonic() 检查 main() 中代码执行的总时间。现在是 small.csv:

的时间

这是 large.csv:

我知道那个问题。您的代码的某些部分使其变慢。

首先,让我们尝试只读取每个文件一次。例如:

with(open(sys.argv[1])) as data:
    reader = csv.DictReader(data)
    STR_all = reader.fieldnames
    for row in reader:
        row["name"]
        data_all.append(row)

STR_all 将是一个列表,因此您可以删除行:

STR_all = STR_all[0]

比较的时候可以建立计数器的字典,这样就可以避免重复两次。

例如,通过这样做:

    # enter data with string of STR and it's max repeating time    
    for i in range(len(STR_all)):
        key = STR_all[i]
        STR_with_max_seq[key] = str(compare(key, line))

你可以删除这个:

    # create a dictionary with a list of all STRs and according number of repeating sequences
    STR_with_max_seq = dict(zip(STR_all, max_seq))

最后,改进compare函数,可以避免第一个循环。您想要找出 DNA 中 STR 连续出现的最大次数。因此,您只需按 STR 长度的 windows 遍历 DNA 并比较它们。例如:

def compare(STR, DNA):
    l = len(STR)
    tmp_max = 0
    tmp = 0
    i = 0
        
    # iteration through the whole length of DNA
    while i < len(DNA) - l:  # make sure the last str has length = l
        SSTR = DNA[i : i + l]  # Extract a substre of length l

        if SSTR == STR:
            # You can jump l positions here.
            i += l
            tmp += 1
        else:
            i += 1
            if tmp > tmp_max:
                tmp_max = tmp
            tmp = 0                
    
    return tmp_max