pyarrow Table 过滤 -- huggingface

pyarrow Table Filtering -- huggingface

我正在尝试根据列表中的 ID 过滤数据集。这种方法太慢了。数据集是 Arrow 数据集。 从 huggingface 导入数据。

import numpy as np
from datasets import load_dataset, DatasetDict
from collections import Counter
import pyarrow as pa
import pandas as pd


responses = load_dataset('peixian/rtGender', 'responses', split = 'train')
# post_id_test_list contains list of ids
responses_test = responses.filter(lambda x: x['post_id'] in post_id_test_list)

您从 load_dataset 获得的数据集不是 arrow Dataset but a hugging face Dataset。不过它有一个箭头 table 支持。

应用 lambda 过滤器会很慢,如果你想要更快的 vertorized 操作,你可以尝试直接修改底层箭头 Table:

import pyarrow as pa
import pyarrow.compute as compute


table = responses.data

flags = compute.is_in(table['post_id'], value_set=pa.array(post_id_test_list, pa.int32()))
filtered_table = table.filter(flags)

filtered_respoonse = datasets.DataSet(filtered_table, response.info, response.split)

虽然我不能 100% 确定最后一行是否是使用箭头重新创建数据集的正确方法 table。

这几乎让我到达那里。如前所述,最后一行不起作用,但我可以转换为 pandas / 保存等。谢谢!

import pyarrow as pa
import pyarrow.compute as compute


table = responses.data

flags = compute.is_in(table['post_id'], value_set=pa.array(post_id_test_list, pa.int32()))
filtered_table = table.filter(flags)
filtered_table.to_pandas()