通过循环有效地将数据添加到 h5py 数据集

Efficiently add data to h5py dataset over a loop

我有 torch.tensors 需要保存到磁盘,因为它们是大文件,会占用所有内存。

我是 h5py 的新手,我不知道如何有效地制作数据集。这个过程很慢。

下面是一个非常 MWE,我打算将其转换成一个循环。

import numpy as np
import h5py

data = np.random.random((13, 8, 512, 768))

f = h5py.File('C:\Users\Andrew\Desktop\test_h5\xd.h5', 'w')
dset = f.create_dataset('embeds', shape=(13, 8, 512, 768),
                        maxshape=(None, 8, 512, 768), chunks=(13, 8, 512, 768),
                        dtype=np.float16)

# add first chunk of rows
dset[:] = data[0:13, :, :,]

# Resize the dataset to accommodate the next chunk of rows
dset.resize(26, axis=0)

# Write the next chunk
dset[13:] = np.random.random((13, 8, 512, 768))

# check data
with h5py.File('C:\Users\Andrew\Desktop\test_h5\xd.h5', 'r') as f:
    print(f['embeds'][0:26].shape)
    print(f['embeds'][0:26])
f.close()

编辑:

我在弄清楚如何确保最后附加的数据实际上是最后生成的数据时没有问题,请考虑以下几点:

import numpy as np
import h5py

data = np.random.random((13, 8, 512, 768)).astype(np.float32)

batch_size=8
with h5py.File('SO_65606675.h5', 'w') as f:
    # create empty data set
    dset = f.create_dataset('embeds', shape=(13, 16, 512, 768),
                            maxshape=(13, None, 512, 768), chunks=(13, 8, 512, 768),
                            dtype=np.float32)
    for cnt in range(2):
        # add chunk of rows
        start = cnt*batch_size
        dset[:, start:start+batch_size, :, :] = data[:, :, :, :]

        # Create attribute with last_index value
        dset.attrs['last_index']=(cnt+1)*batch_size


# check data
with h5py.File('SO_65606675.h5', 'r') as f:
    print(f['embeds'].attrs['last_index'])
    print(f['embeds'].shape)
    x = f['embeds'][:, 8:16, :, :]  # get last entry
np.array_equal(x, data)  # passes

Edit2 :我想我在上面有一个错误并且这个有效;将检查我的“真实”数据。

这是一个简单的例子,它综合了我的建议,展示了在你的情况下一切可能如何运作。程序流程摘要:

  1. 打开一个新文件,创建数据集'embeds'shape=(130, 8, 512, 768),然后添加2组数据,写入'last_index'属性然后关闭文件。
  2. 重新以APPEND模式打开文件,访问数据集'embeds',添加更多2 数据集(从 'last_index' 开始),写入 'last_index' 属性并关闭文件。
  3. 上次打开的文件处于读取模式以打印数据集属性和形状 参数.

备注:

  • 我使用 HDFView 直观地验证数据集内容。我发现查看 np.float16 有问题,所以我使用了 np.float32。这应该适用于 np.float16。我会让你验证一下。
  • 您还应该添加标准完整性检查和错误处理。例如:1) 'embeds' 数据集和 'last_index' 属性都存在,2) 检查数据集大小以确认数据适合,以及 3) 如果您的新数据超出当前数据,则调整大小界限。

代码如下:

import numpy as np
import h5py

data = np.random.random((13, 8, 512, 768)).astype(np.float32)

with h5py.File('SO_65606675.h5', 'w') as f:
    # create empty data set
    dset = f.create_dataset('embeds', shape=(130, 8, 512, 768),
                            maxshape=(None, 8, 512, 768), chunks=(13, 8, 512, 768),
                            dtype=np.float32)
    for cnt in range(2):
        # add chunk of rows
        start = cnt*13
        dset[start:start+13, :, :, :] = data[:, :, :, :]
        
        # Create attribute with last_index value
        dset.attrs['last_index']=(cnt+1)*13

# add more data
with h5py.File('SO_65606675.h5', 'a') as f: # USE APPEND MODE
    dset = f['embeds']
    for cnt in range(2):
        start = dset.attrs['last_index']
        # add chunk of rows
        dset[start:start+13, :, :, :] = data[:, :, :, :]
    
        # Resize the dataset to accommodate the next chunk of rows
        #dset.resize(26, axis=0)
        
        # Create attribute with last_index value
        dset.attrs['last_index']=start+(cnt+1)*13

# check data
with h5py.File('SO_65606675.h5', 'r') as f:
    print(f['embeds'].attrs['last_index'])
    print(f['embeds'].shape)
    #print(f['embeds'][0:26])