在 pyarrow table 中删除重复项?

Dropping duplicates in a pyarrow table?

有没有一种方法可以使用纯 pyarrow tables 对数据进行排序并删除重复项?我的目标是根据最大更新时间戳检索每个 ID 的最新版本。

一些额外的细节:我的数据集通常至少分为两个版本:

历史数据集将包括来源中的​​所有更新项目,因此对于发生在它身上的每个更改,单个 ID 可能有重复项(例如,想象一张 Zendesk 或 ServiceNow 票证,票证可以是多次更新)

然后我使用过滤器读取历史数据集,将其转换为 pandas DF,对数据进行排序,然后在某些唯一约束列上删除重复项。

dataset = ds.dataset(history, filesystem, partitioning)
table = dataset.to_table(filter=filter_expression, columns=columns)
df = table.to_pandas().sort_values(sort_columns, ascending=True).drop_duplicates(unique_constraint, keep="last")
table = pa.Table.from_pandas(df=df, schema=table.schema, preserve_index=False)

# ds.write_dataset(final, filesystem, partitioning)

# I tend to write the final dataset using the legacy dataset so I can make use of the partition_filename_cb - that way I can have one file per date_id. Our visualization tool connects to these files directly
# container/dataset/date_id=20210127/20210127.parquet

pq.write_to_dataset(final, filesystem, partition_cols=["date_id"], use_legacy_dataset=True, partition_filename_cb=lambda x: str(x[-1]).split(".")[0] + ".parquet")

如果可能的话,最好取消转换为 pandas 然后再转换回 table。

编辑 2022 年 3 月:PyArrow 正在添加更多功能,尽管目前还没有这个功能。我现在的做法是:

def drop_duplicates(table: pa.Table, column_name: str) -> pa.Table:
    unique_values = pc.unique(table[column_name])
    unique_indices = [pc.index(table[column_name], value).as_py() for value in unique_values]
    mask = np.full((len(table)), False)
    mask[unique_indices] = True
    return table.filter(mask=mask)

//结束编辑

我看到了你的问题,因为我有一个类似的问题,我在工作中解决了它(由于 IP 问题,我不能 post 整个代码,但我会尽力回答我可以。我以前从来没有这样做过)

import pyarrow.compute as pc
import pyarrow as pa
import numpy as np

array = table.column(column_name)
dicts = {dct['values']: dct['counts'] for dct in pc.value_counts(array).to_pylist()}
for key, value in dicts.items():
    # do stuff

我使用 'value_counts' 来查找唯一值以及它们的数量 (https://arrow.apache.org/docs/python/generated/pyarrow.compute.value_counts.html)。然后我迭代了这些值。如果值为 1,我使用

选择了该行
mask = pa.array(np.array(array) == key)
row = table.filter(mask)

如果计数大于 1,我会再次使用 numpy 布尔数组作为掩码来选择第一个或最后一个。

迭代后就像pa.concat_tables(tables)

一样简单

警告:这是一个缓慢的过程。如果您需要一些快速而肮脏的东西,请尝试“独特”选项(也在我提供的 link 中)。

edit/extra::您可以通过在遍历字典时保持布尔掩码的 numpy 数组来使它有点 faster/less 内存密集型。那么最后你 return 一个“table.filter(mask=boolean_mask)”。 虽然我不知道如何计算速度...

edit2: (抱歉进行了多次编辑。我一直在进行大量重构,并试图让它更快地工作。)

您也可以尝试类似的方法:

def drop_duplicates(table: pa.Table, col_name: str) ->pa.Table:
    column_array = table.column(col_name)
    mask_x = np.full((table.shape[0]), False)
    _, mask_indices = np.unique(np.array(column_array), return_index=True)
    mask_x[mask_indices] = True
    return table.filter(mask=mask_x)

下面给出了一个不错的性能。对于具有 5 亿行 的 table,大约 2 分钟。我不这样做的原因 combine_chunks():有一个错误,如果尺寸太大,箭头似乎无法组合块数组。查看详情:https://issues.apache.org/jira/browse/ARROW-10172?src=confmacro

a = [len(tb3['ID'].chunk(i)) for i in range(len(tb3['ID'].chunks))]
c = np.array([np.arange(x) for x in a])
a = ([0]+a)[:-1]
c = pa.chunked_array(c+np.cumsum(a))

    
tb3= tb3.set_column(tb3.shape[1], 'index', c)
selector = tb3.group_by(['ID']).aggregate([("index", "min")])
    
tb3 = tb3.filter(pc.is_in(tb3['index'], value_set=selector['index_min']))

我发现 duckdb 可以在 group by 上提供更好的性能。将上面的最后两行更改为以下内容将提供 2 倍的加速:

import duckdb 
duck = duckdb.connect()
sql = "select first(index) as idx from tb3 group by ID"
duck_res = duck.execute(sql).fetch_arrow_table()
tb3 = tb3.filter(pc.is_in(tb3['index'], value_set=duck_res['idx']))