如何在 BigQuery Standard SQL 中进行可重复采样?

How to do repeatable sampling in BigQuery Standard SQL?

this blog 中,一位 Google Cloud 员工解释了如何在 BigQuery 中为机器学习对数据集进行可重复采样。这对于创建(和复制)train/validation/test 数据分区非常重要。

但是博客使用旧版 SQL,Google has now deprecated 支持标准 SQL。

您将如何重写下面显示的博客示例代码,但使用标准 SQL?

#legacySQL
SELECT
  date,
  airline,
  departure_airport,
  departure_schedule,
  arrival_airport,
  arrival_delay
FROM
  [bigquery-samples:airline_ontime_data.flights]
WHERE
  ABS(HASH(date)) % 10 < 8

标准 SQL 会这样重写查询:

#standardSQL
SELECT
  date,
  airline,
  departure_airport,
  departure_schedule,
  arrival_airport,
  arrival_delay
FROM
  `bigquery-samples.airline_ontime_data.flights`
WHERE
  ABS(MOD(FARM_FINGERPRINT(date), 10)) < 8

具体变化如下:

  • 用于将 Google Cloud 项目与 table 名称分隔开的句点(不是冒号)。
  • 反引号(不是方括号)以转义 table 名称中的连字符。
  • MOD 函数(不是 %)。
  • FARM_FINGERPRINT(不是 HASH)。这实际上是一个不同于 Legacy SQL 的 HASH 的哈希函数,正如博客所暗示的那样 wasn't in fact consistent over time

根据已接受的答案,提供一种更通用的方法来为每一行生成唯一键:

TO_JSON_STRING(STRUCT(col1, col2, ..., colN))

#standardSQL
SELECT
  date,
  airline,
  departure_airport,
  departure_schedule,
  arrival_airport,
  arrival_delay
FROM
  `bigquery-samples.airline_ontime_data.flights`
WHERE
  ABS(MOD(FARM_FINGERPRINT(TO_JSON_STRING(STRUCT(date, airline, arrival_delay))), 10)) < 8

如果没有唯一键标识每一行怎么办?

是的,您的数据集中可能存在设计重复的数据行,使用上述查询,样本集中包含所有或 none 个重复项。

根据您的数据集有多大,您可以尝试对源数据集进行排序并使用 window 函数为每一行生成一个 row_number。然后根据row_number做采样。这个技巧会一直有效,直到你在对数据集进行排序时遇到错误:

Resources exceeded during query execution: The query could not be executed in the allotted memory.

如果我遇到上面的错误怎么办

好吧,上面的方法更容易实现,但是如果你达到了极限,可以考虑做一些更复杂的事情:

  1. 生成一个去重的 table,其中包含它在原始数据集中出现的次数。
  2. 对行进行哈希处理后,根据 COUNT 增加选择行的几率。
  3. 由于您不想使用所有 COUNT 个重复项,您可以再次进行散列以确定样本集中应包含多大的重复项。 (不过在数学上一定有更好的方法)

我在练习中使用的更简洁的版本而不是冗长的 TO_JSON_STRING(STRUCT(col1, col2, ..., colN))

TO_JSON_STRING(t)

FORMAT('%t', t)   

如下例所示

#standardSQL
SELECT 
  date,
  airline,
  departure_airport,
  departure_schedule,
  arrival_airport,
  arrival_delay
FROM
  `bigquery-samples.airline_ontime_data.flights` t
WHERE
  MOD(ABS(FARM_FINGERPRINT(FORMAT('%t', t))), 10) < 8   

#standardSQL
SELECT 
  date,
  airline,
  departure_airport,
  departure_schedule,
  arrival_airport,
  arrival_delay
FROM
  `bigquery-samples.airline_ontime_data.flights` t
WHERE
  MOD(ABS(FARM_FINGERPRINT(TO_JSON_STRING(t))), 10) < 8