Matplotlib:pandas MultiIndex DataFrame 的自定义代码

Matplotlib: custom ticker for pandas MultiIndex DataFrame

我有一个很大的 pandas MultiIndex DataFrame 想要绘制。一个最小的例子看起来像:

import pandas as pd

years = range(2015, 2018)
fields = range(4)
days = range(4)
bands = ['R', 'G', 'B']

index = pd.MultiIndex.from_product(
    [years, fields], names=['year', 'field'])
columns = pd.MultiIndex.from_product(
    [days, bands], names=['day', 'band'])

df = pd.DataFrame(0, index=index, columns=columns)

df.loc[(2015,), (0,)] = 1
df.loc[(2016,), (1,)] = 1
df.loc[(2017,), (2,)] = 1

如果我使用 plt.spy 绘制此图,我得到:

但是,刻度位置和标签不太理想。我希望滴答声完全忽略 MultiIndex 的第二级。使用 IndexLocator and IndexFormatter,我可以执行以下操作:

from matplotlib.ticker import IndexFormatter, IndexLocator

import matplotlib.pyplot as plt

ax = plt.gca()
plt.spy(df)

xbase = len(bands)
xoffset = xbase / 2
xlabels = df.columns.get_level_values('day')
ax.xaxis.set_major_locator(IndexLocator(base=xbase, offset=xoffset))
ax.xaxis.set_major_formatter(IndexFormatter(xlabels))
plt.xlabel('Day')
ax.xaxis.tick_bottom()

ybase = len(fields)
yoffset = ybase / 2
ylabels = df.index.get_level_values('year')
ax.yaxis.set_major_locator(IndexLocator(base=ybase, offset=yoffset))
ax.yaxis.set_major_formatter(IndexFormatter(ylabels))
plt.ylabel('Year')

plt.show()

这正是我想要的:

但这就是问题所在。我的实际 DataFrame 有 15 年、4,000 个字段、365 天和 7 个波段。如果我真的每天都贴标签,标签将难以辨认。我可以每 50 天放置一个刻度,但我希望刻度是动态的,这样当我放大时,刻度会变得更细粒度。基本上我正在寻找的是一个自定义 MultiIndexLocator,它结合了 IndexLocator 的位置和 MaxNLocator.

的活力

奖励:从某种意义上说,我的数据非常好,每年总是有相同数量的字段,每天都有相同数量的波段。但如果情况并非如此呢?我很乐意为 matplotlib 贡献一个适用于任何 MultiIndex DataFrame 的通用 MultiIndexLocatorMultiIndexFormatter

Matplotlib 不了解数据帧或多索引。它只是绘制您提供的数据。 IE。你得到的结果就好像你在绘制 numpy 数据数组一样,spy(df.values).

所以我建议首先正确设置图像的范围,以便您可以使用数字代码。那么 MaxNLocator 应该可以正常工作,除非你没有放大太多。

import numpy as np
import pandas as pd
from matplotlib.ticker import MaxNLocator
import matplotlib.pyplot as plt
plt.rcParams['axes.formatter.useoffset'] = False

years = range(2000, 2018)
fields = range(9) #17
days = range(120) #365
bands = ['R', 'G', 'B', 'A']

index = pd.MultiIndex.from_product(
    [years, fields], names=['year', 'field'])
columns = pd.MultiIndex.from_product(
    [days, bands], names=['day', 'band'])

data = np.random.rand(len(years)*len(fields),len(days)*len(bands))
x,y = np.meshgrid(np.arange(data.shape[1]),np.arange(data.shape[0]))
data += 2*((y//len(fields)+x//len(bands)) % 2)
df = pd.DataFrame(data, index=index, columns=columns)

############
# Plotting
############

xbase = len(bands)
xlabels = df.columns.get_level_values('day')
ybase = len(fields)
ylabels = df.index.get_level_values('year')

extent = [xlabels.min()-np.diff(np.unique(xlabels))[0]/2.,
          xlabels.max()+np.diff(np.unique(xlabels))[0]/2.,
          ylabels.min()-np.diff(np.unique(ylabels))[0]/2.,
          ylabels.max()+np.diff(np.unique(ylabels))[0]/2.,]

fig, ax = plt.subplots()

ax.imshow(df.values, extent=extent, aspect="auto")
ax.set_ylabel('Year')
ax.set_xlabel('Day')

ax.xaxis.set_major_locator(MaxNLocator(integer=True,min_n_ticks=1))
ax.yaxis.set_major_locator(MaxNLocator(integer=True,min_n_ticks=1))


plt.show()