如何使 Matplotlib 散点图作为一个整体透明?

How to make Matplotlib scatterplots transparent as a group?

我正在使用 Matplotlib(python 3.4.0、matplotlib 1.4.3、运行 on Linux Mint 17)制作一些散点图。为每个点单独设置 alpha 透明度非常容易;有什么办法可以将它们设置为一组,使同一组中的两个重叠点不会改变颜色?

示例代码:

import matplotlib.pyplot as plt
import numpy as np

def points(n=100):
    x = np.random.uniform(size=n)
    y = np.random.uniform(size=n)
    return x, y
x1, y1 = points()
x2, y2 = points()
fig = plt.figure(figsize=(4,4))
ax = fig.add_subplot(111, title="Test scatter")
ax.scatter(x1, y1, s=100, color="blue", alpha=0.5)
ax.scatter(x2, y2, s=100, color="red", alpha=0.5)
fig.savefig("test_scatter.png")

此输出结果:

但我想要更像这样的东西:

我可以通过保存为 SVG 并在 Inkscape 中手动分组然后设置透明度来解决,但我真的更喜欢我可以编码的东西。有什么建议吗?

有趣的问题,我认为任何透明度的使用都会导致您想要避免的堆叠效果。您可以手动设置透明类型颜色以更接近您想要的结果,

import matplotlib.pyplot as plt
import numpy as np

def points(n=100):
    x = np.random.uniform(size=n)
    y = np.random.uniform(size=n)
    return x, y
x1, y1 = points()
x2, y2 = points()
fig = plt.figure(figsize=(4,4))
ax = fig.add_subplot(111, title="Test scatter")
alpha = 0.5
ax.scatter(x1, y1, s=100, lw = 0, color=[1., alpha, alpha])
ax.scatter(x2, y2, s=100, lw = 0, color=[alpha, alpha, 1.])
plt.show()

这种方式不包括不同颜色之间的重叠,但你得到,

是的,有趣的问题。你可以用 Shapely 得到这个散点图。这是代码:

import matplotlib.pyplot as plt
import matplotlib.patches as ptc
import numpy as np
from shapely.geometry import Point
from shapely.ops import cascaded_union

n = 100
size = 0.02
alpha = 0.5

def points():
    x = np.random.uniform(size=n)
    y = np.random.uniform(size=n)
    return x, y

x1, y1 = points()
x2, y2 = points()
polygons1 = [Point(x1[i], y1[i]).buffer(size) for i in range(n)]
polygons2 = [Point(x2[i], y2[i]).buffer(size) for i in range(n)]
polygons1 = cascaded_union(polygons1)
polygons2 = cascaded_union(polygons2)

fig = plt.figure(figsize=(4,4))
ax = fig.add_subplot(111, title="Test scatter")
for polygon1 in polygons1:
    polygon1 = ptc.Polygon(np.array(polygon1.exterior), facecolor="red", lw=0, alpha=alpha)
    ax.add_patch(polygon1)
for polygon2 in polygons2:
    polygon2 = ptc.Polygon(np.array(polygon2.exterior), facecolor="blue", lw=0, alpha=alpha)
    ax.add_patch(polygon2)
ax.axis([-0.2, 1.2, -0.2, 1.2])

fig.savefig("test_scatter.png")

结果是:

这是一个非常糟糕的破解方法,但它确实有效。

您会看到,当 Matplotlib 将数据点绘制为可以重叠的单独对象时,它将它们之间的线绘制为单个对象 - 即使该线被数据中的 NaN 分成几部分。

考虑到这一点,您可以这样做:

import numpy as np
from matplotlib import pyplot as plt

plt.rcParams['lines.solid_capstyle'] = 'round'

def expand(x, y, gap=1e-4):
    add = np.tile([0, gap, np.nan], len(x))
    x1 = np.repeat(x, 3) + add
    y1 = np.repeat(y, 3) + add
    return x1, y1

x1, y1 = points()
x2, y2 = points()
fig = plt.figure(figsize=(4,4))
ax = fig.add_subplot(111, title="Test scatter")
ax.plot(*expand(x1, y1), lw=20, color="blue", alpha=0.5)
ax.plot(*expand(x2, y2), lw=20, color="red", alpha=0.5)

fig.savefig("test_scatter.png")
plt.show()

而且每种颜色都会与另一种颜色重叠,但不会与自身重叠。

需要注意的是,您必须注意用于制作每个圆的两点之间的间距。如果它们两个相距很远,那么分离将在您的绘图上可见,但如果它们靠得太近,matplotlib 根本不会绘制线条。这意味着需要根据数据范围选择分离,如果您打算制作交互式绘图,那么如果缩小太多,所有数据点都可能突然消失,如果放大则拉伸太多了。

如您所见,我发现 1e-5 可以很好地分离范围为 [0,1] 的数据。

只需将参数 edgecolors='none' 传递给 plt.scatter()

如果您要绘制的不仅仅是几个点,这里有一个 hack。我必须绘制 >500000 个点,并且 shapely 解决方案不能很好地扩展。我还想绘制圆形以外的其他形状。我选择使用 alpha=1 分别绘制每一层,然后使用 np.frombuffer 读取结果图像(如 所述),然后将 alpha 添加到整个图像并使用 plt.imshow。请注意,此解决方案会丧失对原始 fig 对象和属性的访问权限,因此应在绘制之前对图形进行任何其他修改。

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure

def arr_from_fig(fig):
    canvas = FigureCanvas(fig)
    canvas.draw()
    img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    return img

def points(n=100):
    x = np.random.uniform(size=n)
    y = np.random.uniform(size=n)
    return x, y

x1, y1 = points()
x2, y2 = points()
imgs = list()
figsize = (4, 4)
dpi = 200

for x, y, c in zip([x1, x2], [y1, y2], ['blue', 'red']):
    fig = plt.figure(figsize=figsize, dpi=dpi, tight_layout={'pad':0})
    ax = fig.add_subplot(111)
    ax.scatter(x, y, s=100, color=c, alpha=1)
    ax.axis([-0.2, 1.2, -0.2, 1.2])
    ax.axis('off')
    imgs.append(arr_from_fig(fig))
    plt.close()


fig = plt.figure(figsize=figsize)
alpha = 0.5

alpha_scaled = 255*alpha
for img in imgs:
    img_alpha = np.where((img == 255).all(-1), 0, alpha_scaled).reshape([*img.shape[:2], 1])
    img_show = np.concatenate([img, img_alpha], axis=-1).astype(int)
    plt.imshow(img_show, origin='lower')

ticklabels = ['{:03.1f}'.format(i) for i in np.linspace(-0.2, 1.2, 8, dtype=np.float16)]
plt.xticks(ticks=np.linspace(0, dpi*figsize[0], 8), labels=ticklabels)
plt.yticks(ticks=np.linspace(0, dpi*figsize[1], 8), labels=ticklabels);
plt.title('Test scatter');