大型 3D 阵列上的快速一维线性 np.NaN 插值
Fast 1D linear np.NaN interpolation over large 3D array
我有一个 (z, y, x)
和 shape=(92, 4800, 4800)
的 3D 数组,其中 axis 0
上的每个值代表不同的时间点。在少数情况下,时域中的值获取失败,导致某些值为 np.NaN
。在其他情况下,没有获取任何值,z
沿线的所有值都是 np.NaN
.
使用线性插值沿着 axis 0
填充 np.NaN
的最有效方法是什么,忽略所有值都是 np.NaN
的情况?
这是我正在做的工作示例,它使用 pandas
包装器到 scipy.interpolate.interp1d
。这在原始数据集上每个切片大约需要 2 秒,这意味着整个数组在 2.6 小时内处理完毕。缩小尺寸的示例数据集大约需要 9.5 秒。
import numpy as np
import pandas as pd
# create example data, original is (92, 4800, 4800)
test_arr = np.random.randint(low=-10000, high=10000, size=(92, 480, 480))
test_arr[1:90:7, :, :] = -32768 # NaN fill value in original data
test_arr[:, 1:90:6, 1:90:8] = -32768
def interpolate_nan(arr, method="linear", limit=3):
"""return array interpolated along time-axis to fill missing values"""
result = np.zeros_like(arr, dtype=np.int16)
for i in range(arr.shape[1]):
# slice along y axis, interpolate with pandas wrapper to interp1d
line_stack = pd.DataFrame(data=arr[:,i,:], dtype=np.float32)
line_stack.replace(to_replace=-37268, value=np.NaN, inplace=True)
line_stack.interpolate(method=method, axis=0, inplace=True, limit=limit)
line_stack.replace(to_replace=np.NaN, value=-37268, inplace=True)
result[:, i, :] = line_stack.values.astype(np.int16)
return result
使用示例数据集在我的机器上的性能:
%timeit interpolate_nan(test_arr)
1 loops, best of 3: 9.51 s per loop
编辑:
我应该澄清一下代码正在产生我预期的结果。问题是 - 我怎样才能优化这个过程?
这取决于;你将不得不拿出一张 sheet 的纸并计算如果你 不 插值并仅将这些 填零,你的总体统计数据将得到的误差NaN.
除此之外,我认为你的插值有点过头了。
只需找到每个 NaN,并线性插值到相邻的四个值(即,将 (y +- 1,x +- 1) 处的值相加)——这将严重限制您的错误(自行计算!),并且您没有使用您的案例中使用的任何复杂方法进行插值(您没有定义 method
)。
您可以尝试为每个 z 值预先计算一个 "averaged" 4800x4800 矩阵——这应该不会花很长时间——通过在矩阵上应用十字形内核(这一切都非常形象-processing-like,这里)。在 NaN 的情况下,一些平均值将是 NaN(NaN 在邻域中的每个平均像素),但您不关心 - 除非有两个相邻的 NaN,即您想要替换的 NaN 单元格原矩阵均为实值
然后您只需将所有 NaN 替换为平均矩阵中的值。
将其速度与 "manual" 计算您找到的每个 NaN 的邻域平均值的速度进行比较。
我最近在 numba 的帮助下为我的特定用例解决了这个问题,并且 a little writeup on it.
from numba import jit
@jit(nopython=True)
def interpolate_numba(arr, no_data=-32768):
"""return array interpolated along time-axis to fill missing values"""
result = np.zeros_like(arr, dtype=np.int16)
for x in range(arr.shape[2]):
# slice along x axis
for y in range(arr.shape[1]):
# slice along y axis
for z in range(arr.shape[0]):
value = arr[z,y,x]
if z == 0: # don't interpolate first value
new_value = value
elif z == len(arr[:,0,0])-1: # don't interpolate last value
new_value = value
elif value == no_data: # interpolate
left = arr[z-1,y,x]
right = arr[z+1,y,x]
# look for valid neighbours
if left != no_data and right != no_data: # left and right are valid
new_value = (left + right) / 2
elif left == no_data and z == 1: # boundary condition left
new_value = value
elif right == no_data and z == len(arr[:,0,0])-2: # boundary condition right
new_value = value
elif left == no_data and right != no_data: # take second neighbour to the left
more_left = arr[z-2,y,x]
if more_left == no_data:
new_value = value
else:
new_value = (more_left + right) / 2
elif left != no_data and right == no_data: # take second neighbour to the right
more_right = arr[z+2,y,x]
if more_right == no_data:
new_value = value
else:
new_value = (more_right + left) / 2
elif left == no_data and right == no_data: # take second neighbour on both sides
more_left = arr[z-2,y,x]
more_right = arr[z+2,y,x]
if more_left != no_data and more_right != no_data:
new_value = (more_left + more_right) / 2
else:
new_value = value
else:
new_value = value
else:
new_value = value
result[z,y,x] = int(new_value)
return result
这比我的初始代码快 20 倍。
提问者借numba
的优势给出了很好的回答。我真的很感激,但我不能完全同意 interpolate_numba
函数中的内容。我不认为对特定点进行线性插值的逻辑是找到其左右邻居的平均值。为了说明,假设我们有一个数组 [1,nan,nan,4,nan,6],上面的 interpolate_numba
函数可能 return [1,2.5,2.5,4,5,6 ](只是理论上的推论),而 pandas
wrapper 肯定会 return [1,2,3,4,5,6]。相反,我认为对特定点进行线性插值的逻辑是找到它的左右邻居,使用它们的值来确定一条线(即斜率和截距),最后计算插值值。下面显示了我的代码。为了简单起见,我假设输入数据是一个包含 nan 值的 3-D 数组。我规定第一个和最后一个元素等效于它们的左右最近邻居(即 pandas
中的 limit_direction='both'
)。我没有指定连续插值的最大数量(即pandas
中没有limit
)。
import numpy as np
from numba import jit
@jit(nopython=True)
def f(arr_3d):
result=np.zeros_like(arr_3d)
for i in range(arr_3d.shape[1]):
for j in range(arr_3d.shape[2]):
arr=arr_3d[:,i,j]
# If all elements are nan then cannot conduct linear interpolation.
if np.sum(np.isnan(arr))==arr.shape[0]:
result[:,i,j]=arr
else:
# If the first elemet is nan, then assign the value of its right nearest neighbor to it.
if np.isnan(arr[0]):
arr[0]=arr[~np.isnan(arr)][0]
# If the last element is nan, then assign the value of its left nearest neighbor to it.
if np.isnan(arr[-1]):
arr[-1]=arr[~np.isnan(arr)][-1]
# If the element is in the middle and its value is nan, do linear interpolation using neighbor values.
for k in range(arr.shape[0]):
if np.isnan(arr[k]):
x=k
x1=x-1
x2=x+1
# Find left neighbor whose value is not nan.
while x1>=0:
if np.isnan(arr[x1]):
x1=x1-1
else:
y1=arr[x1]
break
# Find right neighbor whose value is not nan.
while x2<arr.shape[0]:
if np.isnan(arr[x2]):
x2=x2+1
else:
y2=arr[x2]
break
# Calculate the slope and intercept determined by the left and right neighbors.
slope=(y2-y1)/(x2-x1)
intercept=y1-slope*x1
# Linear interpolation and assignment.
y=slope*x+intercept
arr[x]=y
result[:,i,j]=arr
return result
初始化一个包含一些 nans 的 3-D 数组,我检查了我的代码,它可以给出与 pandas
包装器相同的答案。通过 pandas
包装器代码会有点混乱,因为 pandas 只能处理二维数据。
使用我的代码
y1=np.ones((2,2))
y2=y1+1
y3=y2+np.nan
y4=y2+2
y5=y1+np.nan
y6=y4+2
y1[1,1]=np.nan
y2[0,0]=np.nan
y4[1,1]=np.nan
y6[1,1]=np.nan
y=np.stack((y1,y2,y3,y4,y5,y6),axis=0)
print(y)
print("="*10)
f(y)
使用 pandas 包装器
import pandas as pd
y1=np.ones((2,2)).flatten()
y2=y1+1
y3=y2+np.nan
y4=y2+2
y5=y1+np.nan
y6=y4+2
y1[3]=np.nan
y2[0]=np.nan
y4[3]=np.nan
y6[3]=np.nan
y=pd.DataFrame(np.stack([y1,y2,y3,y4,y5,y6],axis=0))
y=y.interpolate(method='linear', limit_direction='both', axis=0)
y_numpy=y.to_numpy()
y_numpy.shape=((6,2,2))
print(np.stack([y1,y2,y3,y4,y5,y6],axis=0).reshape(6,2,2))
print("="*10)
print(y_numpy)
输出将相同
[[[ 1. 1.]
[ 1. nan]]
[[nan 2.]
[ 2. 2.]]
[[nan nan]
[nan nan]]
[[ 4. 4.]
[ 4. nan]]
[[nan nan]
[nan nan]]
[[ 6. 6.]
[ 6. nan]]]
==========
[[[1. 1.]
[1. 2.]]
[[2. 2.]
[2. 2.]]
[[3. 3.]
[3. 2.]]
[[4. 4.]
[4. 2.]]
[[5. 5.]
[5. 2.]]
[[6. 6.]
[6. 2.]]]
使用test_arr
数据增加到(92,4800,4800)作为输入,我发现只需要大约40秒就可以完成插值!
test_arr = np.random.randint(low=-10000, high=10000, size=(92, 4800, 4800))
test_arr[1:90:7, :, :] = np.nan # NaN fill value in original data
test_arr[2,:,:] = np.nan
test_arr[:, 1:479:6, 1:479:8] = np.nan
%time f(test_arr)
输出
CPU times: user 32.5 s, sys: 9.13 s, total: 41.6 s
Wall time: 41.6 s
我有一个 (z, y, x)
和 shape=(92, 4800, 4800)
的 3D 数组,其中 axis 0
上的每个值代表不同的时间点。在少数情况下,时域中的值获取失败,导致某些值为 np.NaN
。在其他情况下,没有获取任何值,z
沿线的所有值都是 np.NaN
.
使用线性插值沿着 axis 0
填充 np.NaN
的最有效方法是什么,忽略所有值都是 np.NaN
的情况?
这是我正在做的工作示例,它使用 pandas
包装器到 scipy.interpolate.interp1d
。这在原始数据集上每个切片大约需要 2 秒,这意味着整个数组在 2.6 小时内处理完毕。缩小尺寸的示例数据集大约需要 9.5 秒。
import numpy as np
import pandas as pd
# create example data, original is (92, 4800, 4800)
test_arr = np.random.randint(low=-10000, high=10000, size=(92, 480, 480))
test_arr[1:90:7, :, :] = -32768 # NaN fill value in original data
test_arr[:, 1:90:6, 1:90:8] = -32768
def interpolate_nan(arr, method="linear", limit=3):
"""return array interpolated along time-axis to fill missing values"""
result = np.zeros_like(arr, dtype=np.int16)
for i in range(arr.shape[1]):
# slice along y axis, interpolate with pandas wrapper to interp1d
line_stack = pd.DataFrame(data=arr[:,i,:], dtype=np.float32)
line_stack.replace(to_replace=-37268, value=np.NaN, inplace=True)
line_stack.interpolate(method=method, axis=0, inplace=True, limit=limit)
line_stack.replace(to_replace=np.NaN, value=-37268, inplace=True)
result[:, i, :] = line_stack.values.astype(np.int16)
return result
使用示例数据集在我的机器上的性能:
%timeit interpolate_nan(test_arr)
1 loops, best of 3: 9.51 s per loop
编辑:
我应该澄清一下代码正在产生我预期的结果。问题是 - 我怎样才能优化这个过程?
这取决于;你将不得不拿出一张 sheet 的纸并计算如果你 不 插值并仅将这些 填零,你的总体统计数据将得到的误差NaN.
除此之外,我认为你的插值有点过头了。
只需找到每个 NaN,并线性插值到相邻的四个值(即,将 (y +- 1,x +- 1) 处的值相加)——这将严重限制您的错误(自行计算!),并且您没有使用您的案例中使用的任何复杂方法进行插值(您没有定义 method
)。
您可以尝试为每个 z 值预先计算一个 "averaged" 4800x4800 矩阵——这应该不会花很长时间——通过在矩阵上应用十字形内核(这一切都非常形象-processing-like,这里)。在 NaN 的情况下,一些平均值将是 NaN(NaN 在邻域中的每个平均像素),但您不关心 - 除非有两个相邻的 NaN,即您想要替换的 NaN 单元格原矩阵均为实值
然后您只需将所有 NaN 替换为平均矩阵中的值。
将其速度与 "manual" 计算您找到的每个 NaN 的邻域平均值的速度进行比较。
我最近在 numba 的帮助下为我的特定用例解决了这个问题,并且 a little writeup on it.
from numba import jit
@jit(nopython=True)
def interpolate_numba(arr, no_data=-32768):
"""return array interpolated along time-axis to fill missing values"""
result = np.zeros_like(arr, dtype=np.int16)
for x in range(arr.shape[2]):
# slice along x axis
for y in range(arr.shape[1]):
# slice along y axis
for z in range(arr.shape[0]):
value = arr[z,y,x]
if z == 0: # don't interpolate first value
new_value = value
elif z == len(arr[:,0,0])-1: # don't interpolate last value
new_value = value
elif value == no_data: # interpolate
left = arr[z-1,y,x]
right = arr[z+1,y,x]
# look for valid neighbours
if left != no_data and right != no_data: # left and right are valid
new_value = (left + right) / 2
elif left == no_data and z == 1: # boundary condition left
new_value = value
elif right == no_data and z == len(arr[:,0,0])-2: # boundary condition right
new_value = value
elif left == no_data and right != no_data: # take second neighbour to the left
more_left = arr[z-2,y,x]
if more_left == no_data:
new_value = value
else:
new_value = (more_left + right) / 2
elif left != no_data and right == no_data: # take second neighbour to the right
more_right = arr[z+2,y,x]
if more_right == no_data:
new_value = value
else:
new_value = (more_right + left) / 2
elif left == no_data and right == no_data: # take second neighbour on both sides
more_left = arr[z-2,y,x]
more_right = arr[z+2,y,x]
if more_left != no_data and more_right != no_data:
new_value = (more_left + more_right) / 2
else:
new_value = value
else:
new_value = value
else:
new_value = value
result[z,y,x] = int(new_value)
return result
这比我的初始代码快 20 倍。
提问者借numba
的优势给出了很好的回答。我真的很感激,但我不能完全同意 interpolate_numba
函数中的内容。我不认为对特定点进行线性插值的逻辑是找到其左右邻居的平均值。为了说明,假设我们有一个数组 [1,nan,nan,4,nan,6],上面的 interpolate_numba
函数可能 return [1,2.5,2.5,4,5,6 ](只是理论上的推论),而 pandas
wrapper 肯定会 return [1,2,3,4,5,6]。相反,我认为对特定点进行线性插值的逻辑是找到它的左右邻居,使用它们的值来确定一条线(即斜率和截距),最后计算插值值。下面显示了我的代码。为了简单起见,我假设输入数据是一个包含 nan 值的 3-D 数组。我规定第一个和最后一个元素等效于它们的左右最近邻居(即 pandas
中的 limit_direction='both'
)。我没有指定连续插值的最大数量(即pandas
中没有limit
)。
import numpy as np
from numba import jit
@jit(nopython=True)
def f(arr_3d):
result=np.zeros_like(arr_3d)
for i in range(arr_3d.shape[1]):
for j in range(arr_3d.shape[2]):
arr=arr_3d[:,i,j]
# If all elements are nan then cannot conduct linear interpolation.
if np.sum(np.isnan(arr))==arr.shape[0]:
result[:,i,j]=arr
else:
# If the first elemet is nan, then assign the value of its right nearest neighbor to it.
if np.isnan(arr[0]):
arr[0]=arr[~np.isnan(arr)][0]
# If the last element is nan, then assign the value of its left nearest neighbor to it.
if np.isnan(arr[-1]):
arr[-1]=arr[~np.isnan(arr)][-1]
# If the element is in the middle and its value is nan, do linear interpolation using neighbor values.
for k in range(arr.shape[0]):
if np.isnan(arr[k]):
x=k
x1=x-1
x2=x+1
# Find left neighbor whose value is not nan.
while x1>=0:
if np.isnan(arr[x1]):
x1=x1-1
else:
y1=arr[x1]
break
# Find right neighbor whose value is not nan.
while x2<arr.shape[0]:
if np.isnan(arr[x2]):
x2=x2+1
else:
y2=arr[x2]
break
# Calculate the slope and intercept determined by the left and right neighbors.
slope=(y2-y1)/(x2-x1)
intercept=y1-slope*x1
# Linear interpolation and assignment.
y=slope*x+intercept
arr[x]=y
result[:,i,j]=arr
return result
初始化一个包含一些 nans 的 3-D 数组,我检查了我的代码,它可以给出与 pandas
包装器相同的答案。通过 pandas
包装器代码会有点混乱,因为 pandas 只能处理二维数据。
使用我的代码
y1=np.ones((2,2))
y2=y1+1
y3=y2+np.nan
y4=y2+2
y5=y1+np.nan
y6=y4+2
y1[1,1]=np.nan
y2[0,0]=np.nan
y4[1,1]=np.nan
y6[1,1]=np.nan
y=np.stack((y1,y2,y3,y4,y5,y6),axis=0)
print(y)
print("="*10)
f(y)
使用 pandas 包装器
import pandas as pd
y1=np.ones((2,2)).flatten()
y2=y1+1
y3=y2+np.nan
y4=y2+2
y5=y1+np.nan
y6=y4+2
y1[3]=np.nan
y2[0]=np.nan
y4[3]=np.nan
y6[3]=np.nan
y=pd.DataFrame(np.stack([y1,y2,y3,y4,y5,y6],axis=0))
y=y.interpolate(method='linear', limit_direction='both', axis=0)
y_numpy=y.to_numpy()
y_numpy.shape=((6,2,2))
print(np.stack([y1,y2,y3,y4,y5,y6],axis=0).reshape(6,2,2))
print("="*10)
print(y_numpy)
输出将相同
[[[ 1. 1.]
[ 1. nan]]
[[nan 2.]
[ 2. 2.]]
[[nan nan]
[nan nan]]
[[ 4. 4.]
[ 4. nan]]
[[nan nan]
[nan nan]]
[[ 6. 6.]
[ 6. nan]]]
==========
[[[1. 1.]
[1. 2.]]
[[2. 2.]
[2. 2.]]
[[3. 3.]
[3. 2.]]
[[4. 4.]
[4. 2.]]
[[5. 5.]
[5. 2.]]
[[6. 6.]
[6. 2.]]]
使用test_arr
数据增加到(92,4800,4800)作为输入,我发现只需要大约40秒就可以完成插值!
test_arr = np.random.randint(low=-10000, high=10000, size=(92, 4800, 4800))
test_arr[1:90:7, :, :] = np.nan # NaN fill value in original data
test_arr[2,:,:] = np.nan
test_arr[:, 1:479:6, 1:479:8] = np.nan
%time f(test_arr)
输出
CPU times: user 32.5 s, sys: 9.13 s, total: 41.6 s
Wall time: 41.6 s