Python:动画 3D 散点图变慢

Python: Animated 3D Scatterplot gets slow

我的程序在我的文件中为每个时间步绘制粒子的位置。不幸的是,它变得越来越慢,尽管我使用了 matplotlib.animation。瓶颈在哪里?

我的两个粒子数据文件如下所示:

#     x   y   z
# t1  1   2   4
#     4   1   3
# t2  4   0   4
#     3   2   9
# t3  ...

我的脚本:

import numpy as np                          
import matplotlib.pyplot as plt            
from mpl_toolkits.mplot3d import Axes3D
import mpl_toolkits.mplot3d.axes3d as p3
import matplotlib.animation as animation

# Number of particles
numP = 2
# Dimensions
DIM = 3
timesteps = 2000

with open('//home//data.dat', 'r') as fp:
    particleData = []
    for line in fp:
        line = line.split()
        particleData.append(line)

x = [float(item[0]) for item in particleData]
y = [float(item[1]) for item in particleData]
z = [float(item[2]) for item in particleData]      

# Attaching 3D axis to the figure
fig = plt.figure()
ax = p3.Axes3D(fig)

# Setting the axes properties
border = 1
ax.set_xlim3d([-border, border])
ax.set_ylim3d([-border, border])
ax.set_zlim3d([-border, border])


def animate(i):
    global x, y, z, numP
    #ax.clear()
    ax.set_xlim3d([-border, border])
    ax.set_ylim3d([-border, border])
    ax.set_zlim3d([-border, border])
    idx0 = i*numP
    idx1 = numP*(i+1)
    ax.scatter(x[idx0:idx1],y[idx0:idx1],z[idx0:idx1])

ani = animation.FuncAnimation(fig, animate, frames=timesteps, interval=1, blit=False, repeat=False)
plt.show()

我建议在这种情况下使用 pyqtgraph。来自文档的引用:

Its primary goals are 1) to provide fast, interactive graphics for displaying data (plots, video, etc.) and 2) to provide tools to aid in rapid application development (for example, property trees such as used in Qt Designer).

您可以在安装后查看一些示例:

import pyqtgraph.examples
pyqtgraph.examples.run()

这个小代码片段生成 1000 个随机点,并通过不断更新不透明度将它们显示在 3D 散点图中,类似于 pyqtgraph.examples:

中的 3D 散点图示例
from pyqtgraph.Qt import QtCore, QtGui
import pyqtgraph.opengl as gl
import numpy as np

app = QtGui.QApplication([])
w = gl.GLViewWidget()
w.show()
g = gl.GLGridItem()
w.addItem(g)

#generate random points from -10 to 10, z-axis positive
pos = np.random.randint(-10,10,size=(1000,3))
pos[:,2] = np.abs(pos[:,2])

sp2 = gl.GLScatterPlotItem(pos=pos)
w.addItem(sp2)

#generate a color opacity gradient
color = np.zeros((pos.shape[0],4), dtype=np.float32)
color[:,0] = 1
color[:,1] = 0
color[:,2] = 0.5
color[0:100,3] = np.arange(0,100)/100.

def update():
    ## update volume colors
    global color
    color = np.roll(color,1, axis=0)
    sp2.setData(color=color)

t = QtCore.QTimer()
t.timeout.connect(update)
t.start(50)


## Start Qt event loop unless running in interactive mode.
if __name__ == '__main__':
    import sys
    if (sys.flags.interactive != 1) or not hasattr(QtCore, PYQT_VERSION'):
        QtGui.QApplication.instance().exec_()

让您了解性能的小 gif:

编辑:

在每个时间步长显示多个点有点棘手,因为 gl.GLScatterPlotItem 仅将 (N,3)-arrays 作为点位置,请参阅 here。您可以尝试制作一个 ScatterPlotItems 的字典,其中每个字典都包含特定点的所有时间步长。然后需要相应地调整更新功能。您可以在下面找到一个示例,其中 pos 是一个 (100,10,3)-array,代表每个点的 100 个时间步长。我将更新时间减少到 1000 ms 以获得较慢的动画效果。

from pyqtgraph.Qt import QtCore, QtGui
import pyqtgraph.opengl as gl
import numpy as np

app = QtGui.QApplication([])
w = gl.GLViewWidget()
w.show()
g = gl.GLGridItem()
w.addItem(g)

pos = np.random.randint(-10,10,size=(100,10,3))
pos[:,:,2] = np.abs(pos[:,:,2])

ScatterPlotItems = {}
for point in np.arange(10):
    ScatterPlotItems[point] = gl.GLScatterPlotItem(pos=pos[:,point,:])
    w.addItem(ScatterPlotItems[point])

color = np.zeros((pos.shape[0],10,4), dtype=np.float32)
color[:,:,0] = 1
color[:,:,1] = 0
color[:,:,2] = 0.5
color[0:5,:,3] = np.tile(np.arange(1,6)/5., (10,1)).T

def update():
    ## update volume colors
    global color
    for point in np.arange(10):
        ScatterPlotItems[point].setData(color=color[:,point,:])
    color = np.roll(color,1, axis=0)

t = QtCore.QTimer()
t.timeout.connect(update)
t.start(1000)


## Start Qt event loop unless running in interactive mode.
if __name__ == '__main__':
    import sys
    if (sys.flags.interactive != 1) or not hasattr(QtCore, 'PYQT_VERSION'):
    QtGui.QApplication.instance().exec_()

请记住,在此示例中,所有 点都显示在散点图中,但是,颜色不透明度(颜色数组中的第 4 维)每次都会更新获取动画的步骤。您也可以尝试更新点而不是颜色以获得更好的性能...

我猜你的瓶颈是在动画的每一帧调用 ax.scatterax.set_xlim3d 以及类似的东西。

理想情况下,您应该调用 scatter 一次,然后在 animate 函数中使用 scatter 返回的对象及其 set_... 属性 (more details here) .

我不知道如何使用 scatter 来实现,但是如果您改用 ax.plot(x, y, z, 'o'),则可以按照演示方法 here.

使用一些随机数据 x, y, z。它会像这样工作

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import mpl_toolkits.mplot3d.axes3d as p3
import matplotlib.animation as animation
from numpy.random import random

# Number of particles
numP = 2
# Dimensions
DIM = 3
timesteps = 2000

x, y, z = random(timesteps), random(timesteps), random(timesteps)

# Attaching 3D axis to the figure
fig = plt.figure()
ax = p3.Axes3D(fig)

# Setting the axes properties
border = 1
ax.set_xlim3d([-border, border])
ax.set_ylim3d([-border, border])
ax.set_zlim3d([-border, border])
line = ax.plot(x[:1], y[:1], z[:1], 'o')[0]


def animate(i):
    global x, y, z, numP
    idx1 = numP*(i+1)
    # join x and y into single 2 x N array
    xy_data = np.c_[x[:idx1], y[:idx1]].T
    line.set_data(xy_data)
    line.set_3d_properties(z[:idx1])

ani = animation.FuncAnimation(fig, animate, frames=timesteps, interval=1, blit=False, repeat=False)
plt.show()