matplotlib:在 3D 坐标轴中绘制路径

matplotlib: Plotting the path in 3D axis

我想根据 x、y、z 位置数据绘制路径。下面是一个可复现的例子:所有的线条都从 0 开始,而不是一段接一段地连续绘制。

import seaborn as sns
# Load the iris sample data to reproduce the scenario.
data = sns.load_dataset("iris")
# Encode each species name as a numeric label (1, 2, 3) to mimic the real data.
cat_lbl = {'setosa': 1, 'versicolor': 2,'virginica' : 3}
data['cat_lbl'] = data['species'].map(cat_lbl)


# Subplot titles, indexed by (label - 1).
species = ['setosa', 'versicolor', 'virginica']


import matplotlib as mpl
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt

# Each of these frames keeps TWO columns: the measurement and its cat_lbl.
sepal_length = data.loc[:,['sepal_length','cat_lbl']]
sepal_width = data.loc[:,['sepal_width','cat_lbl']]
petal_length = data.loc[:,['petal_length','cat_lbl']]

fig = plt.figure(figsize=([20,15]))
for lbl in range(3):
    lbl=lbl+1  # range(3) yields 0..2; cat_lbl uses 1..3
    # Rows for the current label. Because each frame has two columns,
    # .values is an (n, 2) array of [measurement, cat_lbl] rows.
    x=sepal_length[(sepal_length.cat_lbl == lbl)].values
    y=sepal_width[(sepal_width.cat_lbl == lbl)].values
    z=petal_length[(petal_length.cat_lbl == lbl)].values


    # Position lbl (1..3) in a 3x3 grid, so only the first row is filled.
    ax=fig.add_subplot(3,3,lbl, projection='3d')
    # flatten() interleaves measurement and label values, so every other
    # point is (lbl, lbl, lbl) and the line keeps jumping back to it --
    # this is the problem being reported.
    ax.plot(x.flatten(),y.flatten(),z.flatten())
    ax.set_title(species[lbl-1])
plt.show()

尝试使用 ax.plot3D(...) 代替 ax.plot(...),参见这篇 3D 绘图教程中的示例:

# Self-contained version of the tutorial example: draw a 3D helix line plus a
# noisy scatter of points around it on a single 3D axes.
import numpy as np
import matplotlib.pyplot as plt
# Importing mplot3d registers the '3d' projection on matplotlib < 3.2.
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

# NOTE: on an Axes3D, plot3D is an alias of plot and scatter3D an alias of
# scatter, so switching between the names does not change what is drawn.
ax = plt.axes(projection='3d')

# Data for a three-dimensional line
zline = np.linspace(0, 15, 1000)
xline = np.sin(zline)
yline = np.cos(zline)
ax.plot3D(xline, yline, zline, 'gray')

# Data for three-dimensional scattered points
zdata = 15 * np.random.random(100)
xdata = np.sin(zdata) + 0.1 * np.random.randn(100)
ydata = np.cos(zdata) + 0.1 * np.random.randn(100)
ax.scatter3D(xdata, ydata, zdata, c=zdata, cmap='Greens')
plt.show()

你的问题在于下面这几行:

# Each selection keeps two columns (the measurement and cat_lbl), so .values
# yields an (n, 2) array rather than a 1-D vector of coordinates.
x=sepal_length[(sepal_length.cat_lbl == lbl)].values
y=sepal_width[(sepal_width.cat_lbl == lbl)].values
z=petal_length[(petal_length.cat_lbl == lbl)].values

这三个变量实际上是同时包含类别索引 (1、2、3) 的二维数组。因此,当您调用 x.flatten() 展平时,得到的序列会在坐标值和类别索引之间交替出现(可以看到,这些线在第一个图中会不断回到 (1,1),在第二个图中回到 (2,2),在第三个图中回到 (3,3))。

我会这样写你的代码:

import seaborn as sns
import matplotlib as mpl
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt

# Plot one 3D line per iris species (x = sepal length, y = sepal width,
# z = petal length), each species in its own subplot of a 1x3 row.
data = sns.load_dataset("iris")
species = ['setosa', 'versicolor', 'virginica']

fig, axs = plt.subplots(1, 3, figsize=(9, 3), subplot_kw={'projection': '3d'})
for ax, sp in zip(axs.flat, species):
    # Filter on the species name and take single columns, so the arrays
    # passed to plot are 1-D coordinates with no label values mixed in.
    subset = data[data['species'] == sp]
    ax.plot(
        subset['sepal_length'].to_numpy(),
        subset['sepal_width'].to_numpy(),
        subset['petal_length'].to_numpy(),
    )
    ax.set_title(sp)
plt.show()