How to skip erroneous elements at IO level in Apache Beam with Dataflow?

I am running some analysis on tfrecords stored in GCP, but some of the tfrecords in the files are corrupted, so when more than four errors occur while running my pipeline, the whole pipeline fails because of this. I believe this is a constraint of the DataflowRunner rather than of Beam itself.

Here is my processing script:

import argparse
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.metrics.metric import Metrics

from apache_beam.runners.direct import direct_runner
import tensorflow as tf

input_ = "path_to_bucket"


def _parse_example(serialized_example):
  """Return inputs and targets Tensors from a serialized tf.Example."""
  data_fields = {
      "inputs": tf.io.VarLenFeature(tf.int64),
      "targets": tf.io.VarLenFeature(tf.int64)
  }
  parsed = tf.io.parse_single_example(serialized_example, data_fields)
  inputs = tf.sparse.to_dense(parsed["inputs"])
  targets = tf.sparse.to_dense(parsed["targets"])
  return inputs, targets


class MyFnDo(beam.DoFn):

  def __init__(self):
    beam.DoFn.__init__(self)
    self.input_tokens = Metrics.distribution(self.__class__, 'input_tokens')
    self.output_tokens = Metrics.distribution(self.__class__, 'output_tokens')
    self.num_examples = Metrics.counter(self.__class__, 'num_examples')
    self.decode_errors = Metrics.counter(self.__class__, 'decode_errors')

  def process(self, element):
    # inputs = element.features.feature['inputs'].int64_list.value
    # outputs = element.features.feature['outputs'].int64_list.value
    try:
      inputs, outputs = _parse_example(element)
      self.input_tokens.update(len(inputs))
      self.output_tokens.update(len(outputs))
      self.num_examples.inc()
    except Exception:
      self.decode_errors.inc()



def main(argv):
  parser = argparse.ArgumentParser()
  parser.add_argument('--input', dest='input', default=input_, help='input tfrecords')
  # parser.add_argument('--output', dest='output', default='gs://', help='output file')

  known_args, pipeline_args = parser.parse_known_args(argv)
  pipeline_options = PipelineOptions(pipeline_args)

  with beam.Pipeline(options=pipeline_options) as p:
    tfrecords = p | "Read TFRecords" >> beam.io.ReadFromTFRecord(known_args.input,
                                                                 coder=beam.coders.ProtoCoder(tf.train.Example))
    tfrecords | "count mean" >> beam.ParDo(MyFnDo())


if __name__ == '__main__':
  main(None)

So basically, how can I skip the corrupted tfrecords during my analysis and log how many of them there were?
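One workaround I considered (a minimal, untested sketch; it only helps if the failures come from decoding the proto rather than from the TFRecord framing itself) is to read the raw record bytes with the default coder and do the proto decoding inside a DoFn, so the try/except can actually catch and count the bad records:

import apache_beam as beam
import tensorflow as tf
from apache_beam.metrics.metric import Metrics


class SafeDecodeDo(beam.DoFn):
  """Decodes raw TFRecord bytes into tf.train.Example, skipping bad records."""

  def __init__(self):
    self.decode_errors = Metrics.counter(self.__class__, 'decode_errors')

  def process(self, raw_record):
    try:
      # FromString raises DecodeError on a corrupted serialized proto.
      yield tf.train.Example.FromString(raw_record)
    except Exception:
      self.decode_errors.inc()  # count and skip the corrupted record


with beam.Pipeline() as p:
  examples = (p
              | "Read raw bytes" >> beam.io.ReadFromTFRecord(input_)  # default coder is BytesCoder
              | "Safe decode" >> beam.ParDo(SafeDecodeDo()))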

There was a conceptual problem: beam.io.ReadFromTFRecord reads from a single tfrecord (which can be sharded across multiple files), whereas I was passing it a list of many individual tfrecord files, and that caused the error. Switching from ReadFromTFRecord to ReadAllFromTFRecord solved my problem.

from apache_beam.io.tfrecordio import ReadAllFromTFRecord

p = beam.Pipeline(runner=direct_runner.DirectRunner())
tfrecords = p | beam.Create(tf.io.gfile.glob(input_)) | ReadAllFromTFRecord(coder=beam.coders.ProtoCoder(tf.train.Example))
tfrecords | beam.ParDo(MyFnDo())
p.run().wait_until_finish()
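For reference, the same Create + ReadAllFromTFRecord pattern should also run on Dataflow by building the pipeline from PipelineOptions instead of hard-coding the DirectRunner (a sketch; the project, region and bucket values below are placeholders):

import apache_beam as beam
import tensorflow as tf
from apache_beam.io.tfrecordio import ReadAllFromTFRecord
from apache_beam.options.pipeline_options import PipelineOptions

options = PipelineOptions([
    '--runner=DataflowRunner',
    '--project=my-gcp-project',            # placeholder
    '--region=us-central1',                # placeholder
    '--temp_location=gs://my-bucket/tmp',  # placeholder
])

with beam.Pipeline(options=options) as p:
  (p
   | "List files" >> beam.Create(tf.io.gfile.glob(input_))
   | "Read all tfrecords" >> ReadAllFromTFRecord(coder=beam.coders.ProtoCoder(tf.train.Example))
   | "count mean" >> beam.ParDo(MyFnDo()))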