我的代码应该模拟 alpha 衰减的平均能量,它可以作业,但速度很慢。
import numpy as np
from numpy import sin, cos, arccos, pi, arange, fromiter
import matplotlib.pyplot as plt
from random import choices
r_cell, d, r, R, N = 5.5, 15.8, 7.9, 20, arange(1,10000, 50)
def total_decay(N):
theta = 2*pi*np.random.rand(2,N)
phi = arccos(2*np.random.rand(2,N)-1)
x = fromiter((r*sin(phi[0][i])*cos(theta[0][i]) for i in range(N)),float, count=-1)
dx = fromiter((x[i] R*sin(phi[1][i])*cos(theta[1][i]) for i in range(N)), float,count=-1)
y = fromiter((r*sin(phi[0][i])*sin(theta[0][i]) for i in range(N)),float, count=-1)
dy = fromiter((y[i] R*sin(phi[1][i])*sin(theta[1][i]) for i in range(N)),float,count=-1)
z = fromiter((r*cos(phi[0][i]) for i in range(N)),float, count=-1)
dz = fromiter((z[i] R*cos(phi[1][i]) for i in range(N)),float, count=-1)
return x, y, z, dx, dy, dz
def inter(x,y,z,dx,dy,dz, N):
intersections = 0
for i in range(N): #Checks to see if a line between two points intersects with the target cell
a = (dx[i] - x[i])*(dx[i] - x[i]) (dy[i] - y[i])*(dy[i] - y[i]) (dz[i] - z[i])*(dz[i] - z[i])
b = 2*((dx[i] - x[i])*(x[i]-d) (dy[i] - y[i])*(y[i]) (dz[i] - z[i])*(z[i]))
c = d*d x[i]*x[i] y[i]*y[i] z[i]*z[i] - 2*(d*x[i]) - r_cell*r_cell
if b*b - 4*a*c >= 0:
intersections = 1
return intersections
def hits(N):
I = []
for i in range(len(N)):
decay = total_decay(N[i])
I.append(inter(decay[0],decay[1],decay[2],decay[3],decay[4],decay[5],N[i]))
return I
def AE(I,N):
p1, p2 = 52.4 / (52.4 18.9), 18.9 / (52.4 18.9)
E = [choices([5829.6, 5793.1], cum_weights=(p1,p2),k=1)[0] for _ in range(I)]
return sum(E)/N
def list_AE(I,N):
E = [AE(I[i],N[i]) for i in range(len(N))]
return E
plt.plot(N, list_AE(hits(N),N))
plt.title('Average energy per dose with respect to number of decays')
plt.xlabel('Number of decays [N]')
plt.ylabel('Average energy [keV]')
plt.show()
任何有经验的人都可以指出瓶颈发生在哪里,解释它为什么发生以及如何优化它?提前致谢。
uj5u.com热心网友回复:
要找出大部分时间花在您的代码中,请使用分析器检查它。通过像这样包装你的主要代码:
import cProfile
import pstats
profiler = cProfile.Profile()
profiler.enable()
result = list_AE(hits(N), N)
profiler.disable()
stats = pstats.Stats(profiler).sort_stats('tottime')
stats.print_stats()
您将获得以下概述(缩写):
6467670 function calls in 19.982 seconds
Ordered by: internal time
ncalls tottime percall cumtime percall filename:lineno(function)
200 4.766 0.024 4.766 0.024 ./alphadecay.py:24(inter)
995400 2.980 0.000 2.980 0.000 ./alphadecay.py:17(<genexpr>)
995400 2.925 0.000 2.925 0.000 ./alphadecay.py:15(<genexpr>)
995400 2.690 0.000 2.690 0.000 ./alphadecay.py:16(<genexpr>)
995400 2.683 0.000 2.683 0.000 ./alphadecay.py:14(<genexpr>)
995400 1.674 0.000 1.674 0.000 ./alphadecay.py:19(<genexpr>)
995400 1.404 0.000 1.404 0.000 ./alphadecay.py:18(<genexpr>)
1200 0.550 0.000 14.907 0.012 {built-in method numpy.fromiter}
大部分时间都花在inter
函式上,因为它运行了一个巨大的回圈N
。为了改善这一点,您可以使用multiprocessing.Pool
.
另一种加速计算的方法是使用 NumPy 矢量化。也就是说,避免N
在total_decay()
函式内部迭代:
def total_decay(N):
theta = 2 * pi * np.random.rand(2, N)
phi = arccos(2 * np.random.rand(2, N) - 1)
x = r * sin(phi[0]) * cos(theta[0])
y = r * sin(phi[0]) * sin(theta[0])
z = r * cos(phi[0])
dx = x R * sin(phi[1]) * cos(theta[1])
dy = y R * sin(phi[1]) * sin(theta[1])
dz = z R * cos(phi[1])
return x, y, z, dx, dy, dz
我对代码进行了一些整理,使其更具可读性。在这一点上,我强烈建议您遵循 Python 格式约定并使用描述性变量名称使您的代码更易于理解。
uj5u.com热心网友回复:
我不会告诉你瓶颈在哪里,但我可以告诉你如何在复杂的程序中找到瓶颈。关键字是分析。分析器是一个应用程序,它将与您的代码一起运行并测量每个陈述句的执行时间。在线搜索 python 分析器。
穷人的版本将除错和猜测陈述句的执行时间或使用打印陈述句或用于测量执行时间的库。不过,使用分析器是一项并不难学习的重要技能。
uj5u.com热心网友回复:
您应该尽可能避免附加(您在命中中使用它)并使用串列推导式或已经构建的串列代替(如您在 list_AE 中使用的)。我建议您构建一个串列(具有所需的长度),然后只需按其索引填充每个单元格。
0 评论