Matplotlib:pandas MultiIndex DataFrame 的自定义代码
Matplotlib: custom ticker for pandas MultiIndex DataFrame
我有一个很大的 pandas MultiIndex DataFrame 想要绘制。一个最小的例子看起来像:
import pandas as pd
years = range(2015, 2018)
fields = range(4)
days = range(4)
bands = ['R', 'G', 'B']
index = pd.MultiIndex.from_product(
[years, fields], names=['year', 'field'])
columns = pd.MultiIndex.from_product(
[days, bands], names=['day', 'band'])
df = pd.DataFrame(0, index=index, columns=columns)
df.loc[(2015,), (0,)] = 1
df.loc[(2016,), (1,)] = 1
df.loc[(2017,), (2,)] = 1
如果我使用 plt.spy
绘制此图,我得到:
但是,刻度位置和标签不太理想。我希望滴答声完全忽略 MultiIndex 的第二级。使用 IndexLocator
and IndexFormatter
,我可以执行以下操作:
from matplotlib.ticker import IndexFormatter, IndexLocator
import matplotlib.pyplot as plt
ax = plt.gca()
plt.spy(df)
xbase = len(bands)
xoffset = xbase / 2
xlabels = df.columns.get_level_values('day')
ax.xaxis.set_major_locator(IndexLocator(base=xbase, offset=xoffset))
ax.xaxis.set_major_formatter(IndexFormatter(xlabels))
plt.xlabel('Day')
ax.xaxis.tick_bottom()
ybase = len(fields)
yoffset = ybase / 2
ylabels = df.index.get_level_values('year')
ax.yaxis.set_major_locator(IndexLocator(base=ybase, offset=yoffset))
ax.yaxis.set_major_formatter(IndexFormatter(ylabels))
plt.ylabel('Year')
plt.show()
这正是我想要的:
但这就是问题所在。我的实际 DataFrame 有 15 年、4,000 个字段、365 天和 7 个波段。如果我真的每天都贴标签,标签将难以辨认。我可以每 50 天放置一个刻度,但我希望刻度是动态的,这样当我放大时,刻度会变得更细粒度。基本上我正在寻找的是一个自定义 MultiIndexLocator
,它结合了 IndexLocator
的位置和 MaxNLocator
.
的活力
奖励:从某种意义上说,我的数据非常好,每年总是有相同数量的字段,每天都有相同数量的波段。但如果情况并非如此呢?我很乐意为 matplotlib
贡献一个适用于任何 MultiIndex DataFrame 的通用 MultiIndexLocator
和 MultiIndexFormatter
。
Matplotlib 不了解数据帧或多索引。它只是绘制您提供的数据。 IE。你得到的结果就好像你在绘制 numpy 数据数组一样,spy(df.values)
.
所以我建议首先正确设置图像的范围,以便您可以使用数字代码。那么 MaxNLocator
应该可以正常工作,除非你没有放大太多。
import numpy as np
import pandas as pd
from matplotlib.ticker import MaxNLocator
import matplotlib.pyplot as plt
plt.rcParams['axes.formatter.useoffset'] = False
years = range(2000, 2018)
fields = range(9) #17
days = range(120) #365
bands = ['R', 'G', 'B', 'A']
index = pd.MultiIndex.from_product(
[years, fields], names=['year', 'field'])
columns = pd.MultiIndex.from_product(
[days, bands], names=['day', 'band'])
data = np.random.rand(len(years)*len(fields),len(days)*len(bands))
x,y = np.meshgrid(np.arange(data.shape[1]),np.arange(data.shape[0]))
data += 2*((y//len(fields)+x//len(bands)) % 2)
df = pd.DataFrame(data, index=index, columns=columns)
############
# Plotting
############
xbase = len(bands)
xlabels = df.columns.get_level_values('day')
ybase = len(fields)
ylabels = df.index.get_level_values('year')
extent = [xlabels.min()-np.diff(np.unique(xlabels))[0]/2.,
xlabels.max()+np.diff(np.unique(xlabels))[0]/2.,
ylabels.min()-np.diff(np.unique(ylabels))[0]/2.,
ylabels.max()+np.diff(np.unique(ylabels))[0]/2.,]
fig, ax = plt.subplots()
ax.imshow(df.values, extent=extent, aspect="auto")
ax.set_ylabel('Year')
ax.set_xlabel('Day')
ax.xaxis.set_major_locator(MaxNLocator(integer=True,min_n_ticks=1))
ax.yaxis.set_major_locator(MaxNLocator(integer=True,min_n_ticks=1))
plt.show()
我有一个很大的 pandas MultiIndex DataFrame 想要绘制。一个最小的例子看起来像:
import pandas as pd
years = range(2015, 2018)
fields = range(4)
days = range(4)
bands = ['R', 'G', 'B']
index = pd.MultiIndex.from_product(
[years, fields], names=['year', 'field'])
columns = pd.MultiIndex.from_product(
[days, bands], names=['day', 'band'])
df = pd.DataFrame(0, index=index, columns=columns)
df.loc[(2015,), (0,)] = 1
df.loc[(2016,), (1,)] = 1
df.loc[(2017,), (2,)] = 1
如果我使用 plt.spy
绘制此图,我得到:
但是,刻度位置和标签不太理想。我希望滴答声完全忽略 MultiIndex 的第二级。使用 IndexLocator
and IndexFormatter
,我可以执行以下操作:
from matplotlib.ticker import IndexFormatter, IndexLocator
import matplotlib.pyplot as plt
ax = plt.gca()
plt.spy(df)
xbase = len(bands)
xoffset = xbase / 2
xlabels = df.columns.get_level_values('day')
ax.xaxis.set_major_locator(IndexLocator(base=xbase, offset=xoffset))
ax.xaxis.set_major_formatter(IndexFormatter(xlabels))
plt.xlabel('Day')
ax.xaxis.tick_bottom()
ybase = len(fields)
yoffset = ybase / 2
ylabels = df.index.get_level_values('year')
ax.yaxis.set_major_locator(IndexLocator(base=ybase, offset=yoffset))
ax.yaxis.set_major_formatter(IndexFormatter(ylabels))
plt.ylabel('Year')
plt.show()
这正是我想要的:
但这就是问题所在。我的实际 DataFrame 有 15 年、4,000 个字段、365 天和 7 个波段。如果我真的每天都贴标签,标签将难以辨认。我可以每 50 天放置一个刻度,但我希望刻度是动态的,这样当我放大时,刻度会变得更细粒度。基本上我正在寻找的是一个自定义 MultiIndexLocator
,它结合了 IndexLocator
的位置和 MaxNLocator
.
奖励:从某种意义上说,我的数据非常好,每年总是有相同数量的字段,每天都有相同数量的波段。但如果情况并非如此呢?我很乐意为 matplotlib
贡献一个适用于任何 MultiIndex DataFrame 的通用 MultiIndexLocator
和 MultiIndexFormatter
。
Matplotlib 不了解数据帧或多索引。它只是绘制您提供的数据。 IE。你得到的结果就好像你在绘制 numpy 数据数组一样,spy(df.values)
.
所以我建议首先正确设置图像的范围,以便您可以使用数字代码。那么 MaxNLocator
应该可以正常工作,除非你没有放大太多。
import numpy as np
import pandas as pd
from matplotlib.ticker import MaxNLocator
import matplotlib.pyplot as plt
plt.rcParams['axes.formatter.useoffset'] = False
years = range(2000, 2018)
fields = range(9) #17
days = range(120) #365
bands = ['R', 'G', 'B', 'A']
index = pd.MultiIndex.from_product(
[years, fields], names=['year', 'field'])
columns = pd.MultiIndex.from_product(
[days, bands], names=['day', 'band'])
data = np.random.rand(len(years)*len(fields),len(days)*len(bands))
x,y = np.meshgrid(np.arange(data.shape[1]),np.arange(data.shape[0]))
data += 2*((y//len(fields)+x//len(bands)) % 2)
df = pd.DataFrame(data, index=index, columns=columns)
############
# Plotting
############
xbase = len(bands)
xlabels = df.columns.get_level_values('day')
ybase = len(fields)
ylabels = df.index.get_level_values('year')
extent = [xlabels.min()-np.diff(np.unique(xlabels))[0]/2.,
xlabels.max()+np.diff(np.unique(xlabels))[0]/2.,
ylabels.min()-np.diff(np.unique(ylabels))[0]/2.,
ylabels.max()+np.diff(np.unique(ylabels))[0]/2.,]
fig, ax = plt.subplots()
ax.imshow(df.values, extent=extent, aspect="auto")
ax.set_ylabel('Year')
ax.set_xlabel('Day')
ax.xaxis.set_major_locator(MaxNLocator(integer=True,min_n_ticks=1))
ax.yaxis.set_major_locator(MaxNLocator(integer=True,min_n_ticks=1))
plt.show()