scipy griddata 在样本之间产生 nan 值

scipy griddata produces nan values between samples

我正在尝试根据非结构化样本插入网格点。我的样本取自 0.01 到 10(x 轴)和 1e-8 到 1(y 轴)之间的对数 space。当我 运行 此代码时:

from scipy.interpolate import griddata

data = pd.read_csv('data.csv')

param1, param2, errors = data['param1'].values, data['param2'].values, data['error'].values

x = np.linspace(param1.min(), param1.max(), 100, endpoint=True)
y = np.linspace(param2.min(), param2.max(), 100, endpoint=True)

X, Y = np.meshgrid(x, y)

Z = griddata((param1, param2), errors, (X, Y), method='linear')

fig, ax = plt.subplots(figsize=(10, 7))

cax = ax.contourf(X, Y, Z, 25, cmap='hot')
ax.scatter(param1, param2, s=1, color='black', alpha=0.4)
ax.set(xscale='log', yscale='log')

cbar = fig.colorbar(cax)
fig.tight_layout()

我明白了 result.The 白色区域显示 NaN 值。 x 和 y 轴均为对数刻度:

即使白色区域中有样本(散点证明),griddata 也会产生 NaN。 数据中没有 NaNs/infs。 我是漏掉了什么还是只是 Scipy 中的错误?

data.csv

这是由于 X-Y 插值网格的线性间距和轴的对数缩放。这很容易通过几何(“对数”)间隔插值网格来解决。

也可以插值log-space; IMO 这给出了更好看的结果,但它可能无效。

这是你的图的 more-coarsely-sampled 版本,显示了插值网格点如何“聚集”到 log-scaled 图中的右上角。这里轴的顶行显示数据是有限的,底行是“真实”图:

您可以看到 linearly-spaced 示例网格的最左端和最底端点(恰好!)在一组值之外;这尤其糟糕,因为由于对数缩放,下一个最近的点线在视觉上很远。

这是几何插值网格 spaced 的结果,插值也在 space.

中完成

您可以运行下面的代码查看其他两个变体。

from itertools import product

import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import griddata
import pandas as pd

CMAP = None
# crude, to make interpolation grid visible
NX = 11
NY = 11

def plot_general(log_grid=False, log_interp=False):
    data = pd.read_csv('data.csv')

    param1, param2, errors = data['param1'].values, data['param2'].values, data['error'].values

    if log_grid:
        x = np.geomspace(param1.min(), param1.max(), NX)
        y = np.geomspace(param2.min(), param2.max(), NY)
    else:
        x = np.linspace(param1.min(), param1.max(), NX)
        y = np.linspace(param2.min(), param2.max(), NY)

    X, Y = np.meshgrid(x, y)

    if log_interp:
        Z = griddata((np.log10(param1), np.log10(param2)), errors, (np.log10(X), np.log10(Y)), method='linear')
    else:
        Z = griddata((param1, param2), errors, (X, Y), method='linear')

    fZ = np.isfinite(Z)

    fig, ax = plt.subplots(2, 2)

    ax[0,0].contourf(X, Y, fZ, levels=[0.5,1.5])
    ax[0,0].scatter(param1, param2, s=1, color='black')
    ax[0,0].plot(X.flat, Y.flat, '.', color='blue')

    ax[0,1].contourf(X, Y, fZ, levels=[0.5,1.5])
    ax[0,1].scatter(param1, param2, s=1, color='black')
    ax[0,1].plot(X.flat, Y.flat, '.', color='blue')
    ax[0,1].set(xscale='log', yscale='log')

    ax[1,0].contourf(X, Y, Z, levels=25, cmap=CMAP)
    ax[1,0].scatter(param1, param2, s=1, color='black')
    ax[1,0].plot(X.flat, Y.flat, '.', color='blue')
    ax[1,1].contourf(X, Y, Z, levels=25, cmap=CMAP)
    ax[1,1].scatter(param1, param2, s=1, color='black')
    ax[1,1].set(xscale='log', yscale='log')
    ax[1,1].plot(X.flat, Y.flat, '.', color='blue')

    fig.suptitle(f'{log_grid=}, {log_interp=}')
    fig.tight_layout()
    return fig

plt.close('all')

for log_grid, log_interp in product([False, True],
                                    [False, True]):
    fig = plot_general(log_grid, log_interp)
    #if you want to save results:
    #fig.savefig(f'log_grid{log_grid}-log_interp{log_interp}.png')