Python 代码有很大的瓶颈,但是我经验不够,看不出是哪里
Python code has a big bottleneck, but I am not experienced enough to see where it is
我的代码应该为 alpha 衰变的平均能量建模,它有效但速度很慢。
import numpy as np
from numpy import sin, cos, arccos, pi, arange, fromiter
import matplotlib.pyplot as plt
from random import choices
r_cell, d, r, R, N = 5.5, 15.8, 7.9, 20, arange(1,10000, 50)
def total_decay(N):
theta = 2*pi*np.random.rand(2,N)
phi = arccos(2*np.random.rand(2,N)-1)
x = fromiter((r*sin(phi[0][i])*cos(theta[0][i]) for i in range(N)),float, count=-1)
dx = fromiter((x[i] + R*sin(phi[1][i])*cos(theta[1][i]) for i in range(N)), float,count=-1)
y = fromiter((r*sin(phi[0][i])*sin(theta[0][i]) for i in range(N)),float, count=-1)
dy = fromiter((y[i] + R*sin(phi[1][i])*sin(theta[1][i]) for i in range(N)),float,count=-1)
z = fromiter((r*cos(phi[0][i]) for i in range(N)),float, count=-1)
dz = fromiter((z[i] + R*cos(phi[1][i]) for i in range(N)),float, count=-1)
return x, y, z, dx, dy, dz
def inter(x,y,z,dx,dy,dz, N):
intersections = 0
for i in range(N): #Checks to see if a line between two points intersects with the target cell
a = (dx[i] - x[i])*(dx[i] - x[i]) + (dy[i] - y[i])*(dy[i] - y[i]) + (dz[i] - z[i])*(dz[i] - z[i])
b = 2*((dx[i] - x[i])*(x[i]-d) + (dy[i] - y[i])*(y[i])+(dz[i] - z[i])*(z[i]))
c = d*d + x[i]*x[i] + y[i]*y[i] + z[i]*z[i] - 2*(d*x[i]) - r_cell*r_cell
if b*b - 4*a*c >= 0:
intersections += 1
return intersections
def hits(N):
I = []
for i in range(len(N)):
decay = total_decay(N[i])
I.append(inter(decay[0],decay[1],decay[2],decay[3],decay[4],decay[5],N[i]))
return I
def AE(I,N):
p1, p2 = 52.4 / (52.4 + 18.9), 18.9 / (52.4 + 18.9)
E = [choices([5829.6, 5793.1], cum_weights=(p1,p2),k=1)[0] for _ in range(I)]
return sum(E)/N
def list_AE(I,N):
E = [AE(I[i],N[i]) for i in range(len(N))]
return E
plt.plot(N, list_AE(hits(N),N))
plt.title('Average energy per dose with respect to number of decays')
plt.xlabel('Number of decays [N]')
plt.ylabel('Average energy [keV]')
plt.show()
哪位有经验的能指出瓶颈在哪里,解释为什么会发生,如何优化?提前致谢。
我不会告诉你瓶颈在哪里,但是我可以告诉你如何在复杂的程序中找到瓶颈。关键字是分析。探查器是一个应用程序,它将 运行 与您的代码一起并测量每个语句的执行时间。在线搜索 python 探查器。
穷人的版本是调试和猜测语句的执行时间,或者使用打印语句或库来测量执行时间。不过,使用分析器是一项并不难学的重要技能。
要找出大部分时间花在代码中的什么地方,请使用 profiler 检查它。通过像这样包装您的主要代码:
import cProfile
import pstats
profiler = cProfile.Profile()
profiler.enable()
result = list_AE(hits(N), N)
profiler.disable()
stats = pstats.Stats(profiler).sort_stats('tottime')
stats.print_stats()
您将获得以下概述(缩写):
6467670 function calls in 19.982 seconds
Ordered by: internal time
ncalls tottime percall cumtime percall filename:lineno(function)
200 4.766 0.024 4.766 0.024 ./alphadecay.py:24(inter)
995400 2.980 0.000 2.980 0.000 ./alphadecay.py:17(<genexpr>)
995400 2.925 0.000 2.925 0.000 ./alphadecay.py:15(<genexpr>)
995400 2.690 0.000 2.690 0.000 ./alphadecay.py:16(<genexpr>)
995400 2.683 0.000 2.683 0.000 ./alphadecay.py:14(<genexpr>)
995400 1.674 0.000 1.674 0.000 ./alphadecay.py:19(<genexpr>)
995400 1.404 0.000 1.404 0.000 ./alphadecay.py:18(<genexpr>)
1200 0.550 0.000 14.907 0.012 {built-in method numpy.fromiter}
大部分时间花在 inter
函数上,因为它在 N
上运行一个巨大的循环。为了改善这一点,您可以使用 multiprocessing.Pool
.
将其执行并行化到多个线程
另一种加快计算速度的方法是使用 NumPy 向量化。也就是说,避免在 total_decay()
函数内迭代 N
:
def total_decay(N):
theta = 2 * pi * np.random.rand(2, N)
phi = arccos(2 * np.random.rand(2, N) - 1)
x = r * sin(phi[0]) * cos(theta[0])
y = r * sin(phi[0]) * sin(theta[0])
z = r * cos(phi[0])
dx = x + R * sin(phi[1]) * cos(theta[1])
dy = y + R * sin(phi[1]) * sin(theta[1])
dz = z + R * cos(phi[1])
return x, y, z, dx, dy, dz
我对代码进行了一些安排以使其更具可读性。关于这一点,我强烈建议您遵循 Python 格式约定并使用描述性变量名称以使您的代码更易于理解。
我的代码应该为 alpha 衰变的平均能量建模,它有效但速度很慢。
import numpy as np
from numpy import sin, cos, arccos, pi, arange, fromiter
import matplotlib.pyplot as plt
from random import choices
r_cell, d, r, R, N = 5.5, 15.8, 7.9, 20, arange(1,10000, 50)
def total_decay(N):
theta = 2*pi*np.random.rand(2,N)
phi = arccos(2*np.random.rand(2,N)-1)
x = fromiter((r*sin(phi[0][i])*cos(theta[0][i]) for i in range(N)),float, count=-1)
dx = fromiter((x[i] + R*sin(phi[1][i])*cos(theta[1][i]) for i in range(N)), float,count=-1)
y = fromiter((r*sin(phi[0][i])*sin(theta[0][i]) for i in range(N)),float, count=-1)
dy = fromiter((y[i] + R*sin(phi[1][i])*sin(theta[1][i]) for i in range(N)),float,count=-1)
z = fromiter((r*cos(phi[0][i]) for i in range(N)),float, count=-1)
dz = fromiter((z[i] + R*cos(phi[1][i]) for i in range(N)),float, count=-1)
return x, y, z, dx, dy, dz
def inter(x,y,z,dx,dy,dz, N):
intersections = 0
for i in range(N): #Checks to see if a line between two points intersects with the target cell
a = (dx[i] - x[i])*(dx[i] - x[i]) + (dy[i] - y[i])*(dy[i] - y[i]) + (dz[i] - z[i])*(dz[i] - z[i])
b = 2*((dx[i] - x[i])*(x[i]-d) + (dy[i] - y[i])*(y[i])+(dz[i] - z[i])*(z[i]))
c = d*d + x[i]*x[i] + y[i]*y[i] + z[i]*z[i] - 2*(d*x[i]) - r_cell*r_cell
if b*b - 4*a*c >= 0:
intersections += 1
return intersections
def hits(N):
I = []
for i in range(len(N)):
decay = total_decay(N[i])
I.append(inter(decay[0],decay[1],decay[2],decay[3],decay[4],decay[5],N[i]))
return I
def AE(I,N):
p1, p2 = 52.4 / (52.4 + 18.9), 18.9 / (52.4 + 18.9)
E = [choices([5829.6, 5793.1], cum_weights=(p1,p2),k=1)[0] for _ in range(I)]
return sum(E)/N
def list_AE(I,N):
E = [AE(I[i],N[i]) for i in range(len(N))]
return E
plt.plot(N, list_AE(hits(N),N))
plt.title('Average energy per dose with respect to number of decays')
plt.xlabel('Number of decays [N]')
plt.ylabel('Average energy [keV]')
plt.show()
哪位有经验的能指出瓶颈在哪里,解释为什么会发生,如何优化?提前致谢。
我不会告诉你瓶颈在哪里,但是我可以告诉你如何在复杂的程序中找到瓶颈。关键字是分析。探查器是一个应用程序,它将 运行 与您的代码一起并测量每个语句的执行时间。在线搜索 python 探查器。
穷人的版本是调试和猜测语句的执行时间,或者使用打印语句或库来测量执行时间。不过,使用分析器是一项并不难学的重要技能。
要找出大部分时间花在代码中的什么地方,请使用 profiler 检查它。通过像这样包装您的主要代码:
import cProfile
import pstats
profiler = cProfile.Profile()
profiler.enable()
result = list_AE(hits(N), N)
profiler.disable()
stats = pstats.Stats(profiler).sort_stats('tottime')
stats.print_stats()
您将获得以下概述(缩写):
6467670 function calls in 19.982 seconds
Ordered by: internal time
ncalls tottime percall cumtime percall filename:lineno(function)
200 4.766 0.024 4.766 0.024 ./alphadecay.py:24(inter)
995400 2.980 0.000 2.980 0.000 ./alphadecay.py:17(<genexpr>)
995400 2.925 0.000 2.925 0.000 ./alphadecay.py:15(<genexpr>)
995400 2.690 0.000 2.690 0.000 ./alphadecay.py:16(<genexpr>)
995400 2.683 0.000 2.683 0.000 ./alphadecay.py:14(<genexpr>)
995400 1.674 0.000 1.674 0.000 ./alphadecay.py:19(<genexpr>)
995400 1.404 0.000 1.404 0.000 ./alphadecay.py:18(<genexpr>)
1200 0.550 0.000 14.907 0.012 {built-in method numpy.fromiter}
大部分时间花在 inter
函数上,因为它在 N
上运行一个巨大的循环。为了改善这一点,您可以使用 multiprocessing.Pool
.
另一种加快计算速度的方法是使用 NumPy 向量化。也就是说,避免在 total_decay()
函数内迭代 N
:
def total_decay(N):
theta = 2 * pi * np.random.rand(2, N)
phi = arccos(2 * np.random.rand(2, N) - 1)
x = r * sin(phi[0]) * cos(theta[0])
y = r * sin(phi[0]) * sin(theta[0])
z = r * cos(phi[0])
dx = x + R * sin(phi[1]) * cos(theta[1])
dy = y + R * sin(phi[1]) * sin(theta[1])
dz = z + R * cos(phi[1])
return x, y, z, dx, dy, dz
我对代码进行了一些安排以使其更具可读性。关于这一点,我强烈建议您遵循 Python 格式约定并使用描述性变量名称以使您的代码更易于理解。