Beam Java SDK with TFRecord and Compression GZIP

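We use the Beam Java SDK a lot (with Google Cloud Dataflow to run batch jobs), and we noticed something odd (possibly a bug?) when trying to use TFRecordIO together with Compression.GZIP. We were able to come up with some sample code that reproduces the error we are facing.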

To be clear, we are using Beam Java SDK 2.4.

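Suppose we have a PCollection<byte[]>, which can be, for example, a PCollection of proto messages in byte[] format. We usually write it to GCS (Google Cloud Storage) either as Base64-encoded, newline-delimited strings or with TFRecordIO (without compression). We have had no trouble reading the data back from GCS this way for a long time (more than 2.5 years for the former and about 1.5 years for the latter).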

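Recently, we tried TFRecordIO with the Compression.GZIP option, and sometimes we get an exception because the data is considered invalid (while being read). The data itself (the gzip files) is not corrupted. We tested various things and reached the following conclusion.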

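When a byte[] that is compressed under TFRecordIO is above a certain size (I would say at or above 8192 bytes), TFRecordIO.read().withCompression(Compression.GZIP) does not work. Specifically, it throws the following exception: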

Exception in thread "main" java.lang.IllegalStateException: Invalid data
    at org.apache.beam.sdk.repackaged.com.google.common.base.Preconditions.checkState(Preconditions.java:444)
    at org.apache.beam.sdk.io.TFRecordIO$TFRecordCodec.read(TFRecordIO.java:642)
    at org.apache.beam.sdk.io.TFRecordIO$TFRecordSource$TFRecordReader.readNextRecord(TFRecordIO.java:526)
    at org.apache.beam.sdk.io.CompressedSource$CompressedReader.readNextRecord(CompressedSource.java:426)
    at org.apache.beam.sdk.io.FileBasedSource$FileBasedReader.advanceImpl(FileBasedSource.java:473)
    at org.apache.beam.sdk.io.FileBasedSource$FileBasedReader.startImpl(FileBasedSource.java:468)
    at org.apache.beam.sdk.io.OffsetBasedSource$OffsetBasedReader.start(OffsetBasedSource.java:261)
    at org.apache.beam.runners.direct.BoundedReadEvaluatorFactory$BoundedReadEvaluator.processElement(BoundedReadEvaluatorFactory.java:141)
    at org.apache.beam.runners.direct.DirectTransformExecutor.processElements(DirectTransformExecutor.java:161)
    at org.apache.beam.runners.direct.DirectTransformExecutor.run(DirectTransformExecutor.java:125)
    at java.util.concurrent.Executors$RunnableAdapter.call(Executors.java:511)
    at java.util.concurrent.FutureTask.run(FutureTask.java:266)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
    at java.lang.Thread.run(Thread.java:748)

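This is easy to reproduce; see the code at the end. You will also see comments about the byte array length (after testing various sizes, I concluded that 8192 is the magic number).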
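As an aside on the claim that the gzip files themselves are intact: one way to check a file independently of Beam is to decompress it with the JDK and walk the TFRecord framing by hand. The sketch below is only an illustration, not the check we actually ran. It assumes the standard TFRecord layout (8-byte little-endian length, 4-byte CRC of the length, payload, 4-byte CRC of the payload) and skips CRC verification. The class name and the `writeGzippedRecord` helper are hypothetical and exist only so the example runs on its own.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.List;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

public class GzipTfRecordCheck {

  // Reads every record payload from a gzip-compressed TFRecord stream.
  // CRC fields are skipped, not verified (assumption: framing only).
  static List<byte[]> readRecords(InputStream gzipped) throws IOException {
    List<byte[]> records = new ArrayList<>();
    try (DataInputStream in = new DataInputStream(new GZIPInputStream(gzipped))) {
      byte[] header = new byte[12]; // 8-byte length + 4-byte CRC of the length
      while (true) {
        int first = in.read();
        if (first < 0) {
          break; // clean end of stream
        }
        header[0] = (byte) first;
        in.readFully(header, 1, 11); // readFully loops until all bytes arrive
        long length = ByteBuffer.wrap(header).order(ByteOrder.LITTLE_ENDIAN).getLong();
        byte[] data = new byte[(int) length]; // sketch: assumes records fit in an int
        in.readFully(data);
        in.readFully(new byte[4]); // skip the CRC of the payload
        records.add(data);
      }
    }
    return records;
  }

  // Hypothetical helper for this demo only: frames one payload and gzips it.
  // Both CRC fields are written as zeros because readRecords never checks them.
  static byte[] writeGzippedRecord(byte[] payload) throws IOException {
    ByteArrayOutputStream bytes = new ByteArrayOutputStream();
    try (GZIPOutputStream out = new GZIPOutputStream(bytes)) {
      ByteBuffer header = ByteBuffer.allocate(12).order(ByteOrder.LITTLE_ENDIAN);
      header.putLong(payload.length); // the last 4 header bytes stay zero
      out.write(header.array());
      out.write(payload);
      out.write(new byte[4]); // placeholder CRC of the payload
    }
    return bytes.toByteArray();
  }

  public static void main(String[] args) throws IOException {
    byte[] payload = new byte[10000]; // larger than 8192 on purpose
    byte[] file = writeGzippedRecord(payload);
    List<byte[]> records = readRecords(new ByteArrayInputStream(file));
    System.out.println("records=" + records.size() + " firstLength=" + records.get(0).length);
  }
}
```

For a real file, `readRecords` would be given a `FileInputStream` over a local copy of the output shard. If it returns the expected payloads, the problem is more likely in how the reader consumes the stream than in the file.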

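So I am wondering whether this is a bug or a known issue. I could not find anything close to this on Apache Beam's issue tracker, but if there is another forum/site I need to check, please let me know! If this is indeed a bug, what is the right channel for reporting it?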


The code below reproduces the error we are seeing.

A successful run (with parameters 1, 39, 100) shows the following message at the end:

------------ counter metrics from CountDoFn
[counter]             plain_base64_proto_array_len: 8126
[counter]                    plain_base64_proto_in:   1
[counter]               plain_base64_proto_val_cnt:  39
[counter]              tfrecord_gz_proto_array_len: 8126
[counter]                     tfrecord_gz_proto_in:   1
[counter]                tfrecord_gz_proto_val_cnt:  39
[counter]          tfrecord_uncomp_proto_array_len: 8126
[counter]                 tfrecord_uncomp_proto_in:   1
[counter]            tfrecord_uncomp_proto_val_cnt:  39

With parameters (1, 40, 100), the byte array length exceeds 8192 and the exception above is thrown.
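The two lengths (8126 and 8329) can be sanity-checked with a little arithmetic. The sketch below is a back-of-the-envelope estimate, not part of the reproduction. It assumes the usual protobuf wire format (one tag byte plus a varint length for each string field). It also assumes each generated string is 200 characters rather than 100, because `'A' + rd.nextInt(26)` in the main code is an int and `append` writes it as a two-digit number. The 6 bytes for the `count` field is inferred from the observed totals and depends on the seeded random value. The class and method names are hypothetical.

```java
public class ProtoSizeEstimate {

  // Serialized size of one length-delimited string field with a one-byte tag.
  static int stringFieldSize(int payloadLen) {
    int lenVarint = payloadLen < 128 ? 1 : 2; // valid for payloads under 16384 bytes
    return 1 + lenVarint + payloadLen;
  }

  // Estimated size of one TestProto: count field + name + numStrings values.
  static int estimate(int numStrings, int stringLen, int countFieldBytes) {
    int payload = 2 * stringLen; // each "character" is appended as two digits (65..90)
    return countFieldBytes + (1 + numStrings) * stringFieldSize(payload);
  }

  public static void main(String[] args) {
    System.out.println(estimate(39, 100, 6)); // 6 + 40 * 203 = 8126
    System.out.println(estimate(40, 100, 6)); // 6 + 41 * 203 = 8329
  }
}
```

Under these assumptions each extra string adds 203 bytes, so going from 39 to 40 strings is the step that pushes the record past 8192.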

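You can tweak the parameters (inside the CreateRandomProtoData DoFn) to see why the length of the byte[] being compressed matters. It may also help to have the protoc-generated Java class for TestProto, which is used in the main code. Here it is: gist link

For reference, the main code: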

package exp.moloco.dataflow2.compression; // NOTE: Change appropriately.

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Random;
import java.util.TreeMap;

import org.apache.beam.runners.direct.DirectRunner;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.PipelineResult;
import org.apache.beam.sdk.io.Compression;
import org.apache.beam.sdk.io.TFRecordIO;
import org.apache.beam.sdk.io.TextIO;
import org.apache.beam.sdk.metrics.Counter;
import org.apache.beam.sdk.metrics.MetricResult;
import org.apache.beam.sdk.metrics.Metrics;
import org.apache.beam.sdk.metrics.MetricsFilter;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.values.PCollection;
import org.apache.commons.codec.binary.Base64;
import org.joda.time.DateTime;
import org.joda.time.DateTimeZone;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.protobuf.InvalidProtocolBufferException;

import com.moloco.dataflow.test.StackOverflow.TestProto;
import com.moloco.dataflow2.Main;

// @formatter:off
// This code uses TestProto (java class) that is generated by protoc.
// The message definition is as follows (in proto3, but it shouldn't matter):
// message TestProto {
//   int64 count = 1;
//   string name = 2;
//   repeated string values = 3;
// }
// Note that the failure does not depend on whether this proto or any other byte[] is used
// (see the CreateRandomProtoData DoFn below, which generates the data used in this code).
// We tested both, but are presenting this as a concrete example of how (our) code in production can be affected.
// @formatter:on

public class CompressionTester {
  private static final Logger LOG = LoggerFactory.getLogger(CompressionTester.class);

  static final List<String> lines = Arrays.asList("some dummy string that will not be used in this job.");

  // Some GCS buckets where data will be written to.
  // %s will be replaced by some timestamped String for easy debugging.
  static final String PATH_TO_GCS_PLAIN_BASE64 = Main.SOME_BUCKET + "/comp-test/%s/output-plain-base64";
  static final String PATH_TO_GCS_TFRECORD_UNCOMP = Main.SOME_BUCKET + "/comp-test/%s/output-tfrecord-uncompressed";
  static final String PATH_TO_GCS_TFRECORD_GZ = Main.SOME_BUCKET + "/comp-test/%s/output-tfrecord-gzip";

  // This DoFn reads byte[] which represents a proto message (TestProto).
  // It simply counts the number of proto objects it processes
  // as well as the number of Strings each proto object contains.
  // When the pipeline terminates, the values of the Counters will be printed out.
  static class CountDoFn extends DoFn<byte[], TestProto> {

    private final Counter protoIn;
    private final Counter protoValuesCnt;
    private final Counter protoByteArrayLength;

    public CountDoFn(String name) {
      protoIn = Metrics.counter(this.getClass(), name + "_proto_in");
      protoValuesCnt = Metrics.counter(this.getClass(), name + "_proto_val_cnt");
      protoByteArrayLength = Metrics.counter(this.getClass(), name + "_proto_array_len");
    }

    @ProcessElement
    public void processElement(ProcessContext c) throws InvalidProtocolBufferException {
      protoIn.inc();
      TestProto tp = TestProto.parseFrom(c.element());
      protoValuesCnt.inc(tp.getValuesCount());
      protoByteArrayLength.inc(c.element().length);
    }
  }

  // This DoFn emits a number of TestProto objects as byte[].
  // Input to this DoFn is ignored (not used).
  // Each TestProto object contains three fields: count (int64), name (string), and values (repeated string).
  // The three parameters in this DoFn determine
  // (1) the number of proto objects to be generated,
  // (2) the number of (repeated) strings to be added to each proto object, and
  // (3) the length of (each) string.
  // TFRecord with Compression (when reading) fails when the parameters are 1, 40, 100, for instance.
  // TFRecord with Compression (when reading) succeeds when the parameters are 1, 39, 100, for instance.
  static class CreateRandomProtoData extends DoFn<String, byte[]> {

    static final int NUM_PROTOS = 1; // Total number of TestProto objects to be emitted by this DoFn.
    static final int NUM_STRINGS = 40; // Total number of strings in each TestProto object ('repeated string').
    static final int STRING_LEN = 100; // Length of each string object.

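    // Intended to return a random string of length len made of upper-case letters.
    // Note: 'A' + rd.nextInt(26) is an int, so append() writes its decimal value (65..90)
    // and the result is a digit string of length 2 * len. Left unchanged on purpose,
    // because the byte array lengths quoted in this post were produced by this version.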
    static String getRandomString(Random rd, int len) {
      StringBuffer sb = new StringBuffer();
      for (int i = 0; i < len; i++) {
        sb.append('A' + (rd.nextInt(26)));
      }
      return sb.toString();
    }

    // Returns a randomly generated TestProto object.
    // Each string is generated randomly using getRandomString().
    static TestProto getRandomProto(Random rd) {
      TestProto.Builder tpBuilder = TestProto.newBuilder();

      tpBuilder.setCount(rd.nextInt());
      tpBuilder.setName(getRandomString(rd, STRING_LEN));
      for (int i = 0; i < NUM_STRINGS; i++) {
        tpBuilder.addValues(getRandomString(rd, STRING_LEN));
      }

      return tpBuilder.build();
    }

    // Emits TestProto objects as byte[].
    @ProcessElement
    public void processElement(ProcessContext c) {
      // For debugging purposes, we set the seed here.
      Random rd = new Random();
      rd.setSeed(132475);

      for (int n = 0; n < NUM_PROTOS; n++) {
        byte[] data = getRandomProto(rd).toByteArray();
        c.output(data);
        // With parameters (1, 39, 100), the array length is 8126. It works fine.
        // With parameters (1, 40, 100), the array length is 8329. It breaks TFRecord with GZIP.
        System.out.println("\n--------------------------\n");
        System.out.println("byte array length = " + data.length);
        System.out.println("\n--------------------------\n");
      }
    }
  }

  public static void execute() {
    PipelineOptions options = PipelineOptionsFactory.create();
    options.setJobName("compression-tester");
    options.setRunner(DirectRunner.class);

    // For debugging purposes, write files under 'gcsSubDir' so we can easily distinguish.
    final String gcsSubDir =
        String.format("%s-%d", DateTime.now(DateTimeZone.UTC), DateTime.now(DateTimeZone.UTC).getMillis());

    // Write the PCollection<byte[]> (serialized TestProto messages) in 3 different ways to GCS.
    {
      Pipeline pipeline = Pipeline.create(options);

      // Create dummy data which is a PCollection of byte arrays (each array representing a proto message).
      PCollection<byte[]> data = pipeline.apply(Create.of(lines)).apply(ParDo.of(new CreateRandomProtoData()));

      // 1. Write as plain-text with base64 encoding.
      data.apply(ParDo.of(new DoFn<byte[], String>() {
        @ProcessElement
        public void processElement(ProcessContext c) {
          c.output(new String(Base64.encodeBase64(c.element())));
        }
      })).apply(TextIO.write().to(String.format(PATH_TO_GCS_PLAIN_BASE64, gcsSubDir)).withNumShards(1));

      // 2. Write as TFRecord.
      data.apply(TFRecordIO.write().to(String.format(PATH_TO_GCS_TFRECORD_UNCOMP, gcsSubDir)).withNumShards(1));

      // 3. Write as TFRecord-gzip.
      data.apply(TFRecordIO.write().withCompression(Compression.GZIP)
          .to(String.format(PATH_TO_GCS_TFRECORD_GZ, gcsSubDir)).withNumShards(1));

      pipeline.run().waitUntilFinish();
    }

    LOG.info("-------------------------------------------");
    LOG.info("               READ TEST BEGINS ");
    LOG.info("-------------------------------------------");

    // Read the serialized TestProto messages back from GCS in 3 different ways.
    {
      Pipeline pipeline = Pipeline.create(options);

      // 1. Read as plain-text.
      pipeline.apply(TextIO.read().from(String.format(PATH_TO_GCS_PLAIN_BASE64, gcsSubDir) + "*"))
          .apply(ParDo.of(new DoFn<String, byte[]>() {
            @ProcessElement
            public void processElement(ProcessContext c) {
              c.output(Base64.decodeBase64(c.element()));
            }
          })).apply("plain-base64", ParDo.of(new CountDoFn("plain_base64")));

      // 2. Read as TFRecord -> byte array.
      pipeline.apply(TFRecordIO.read().from(String.format(PATH_TO_GCS_TFRECORD_UNCOMP, gcsSubDir) + "*"))
          .apply("tfrecord-uncomp", ParDo.of(new CountDoFn("tfrecord_uncomp")));

      // 3. Read as TFRecord-gz -> byte array.
      // This seems to fail when 'data size' becomes large.
      pipeline
          .apply(TFRecordIO.read().withCompression(Compression.GZIP)
              .from(String.format(PATH_TO_GCS_TFRECORD_GZ, gcsSubDir) + "*"))
          .apply("tfrecord_gz", ParDo.of(new CountDoFn("tfrecord_gz")));

      // 4. Run pipeline.
      PipelineResult res = pipeline.run();
      res.waitUntilFinish();

      // Check CountDoFn's metrics.
      // The numbers should match.
      Map<String, Long> counterValues = new TreeMap<String, Long>();
      for (MetricResult<Long> counter : res.metrics().queryMetrics(MetricsFilter.builder().build()).counters()) {
        counterValues.put(counter.name().name(), counter.committed());
      }
      StringBuffer sb = new StringBuffer();
      sb.append("\n------------ counter metrics from CountDoFn\n");
      for (Entry<String, Long> entry : counterValues.entrySet()) {
        sb.append(String.format("[counter] %40s: %5d\n", entry.getKey(), entry.getValue()));
      }
      LOG.info(sb.toString());
    }
  }
}

This definitely looks like a bug in TFRecordIO. Channel.read() can read fewer bytes than the capacity of the input buffer. 8192 seems to be the buffer size in GzipCompressorInputStream. I filed https://issues.apache.org/jira/browse/BEAM-5412
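To illustrate the short-read behavior in isolation, here is a minimal sketch. It is not the TFRecordIO code and not the eventual fix, only a demonstration of why a single `read()` call cannot be assumed to fill the buffer. `ChunkedChannel` is a hypothetical channel that hands out at most 8192 bytes per call, standing in for a channel backed by a decompressor with an internal buffer of that size.

```java
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ReadableByteChannel;

public class ReadFullyDemo {

  // Hypothetical channel that returns at most `chunk` bytes per read() call.
  static final class ChunkedChannel implements ReadableByteChannel {
    private final byte[] source;
    private final int chunk;
    private int pos = 0;
    private boolean open = true;

    ChunkedChannel(byte[] source, int chunk) {
      this.source = source;
      this.chunk = chunk;
    }

    @Override
    public int read(ByteBuffer dst) {
      if (pos >= source.length) {
        return -1; // end of data
      }
      int n = Math.min(chunk, Math.min(dst.remaining(), source.length - pos));
      dst.put(source, pos, n);
      pos += n;
      return n;
    }

    @Override
    public boolean isOpen() {
      return open;
    }

    @Override
    public void close() {
      open = false;
    }
  }

  // Keeps calling read() until the buffer is full or the channel is exhausted.
  static int readFully(ReadableByteChannel channel, ByteBuffer dst) throws IOException {
    int total = 0;
    while (dst.hasRemaining()) {
      int n = channel.read(dst);
      if (n < 0) {
        break;
      }
      total += n;
    }
    return total;
  }

  public static void main(String[] args) throws IOException {
    byte[] payload = new byte[10000]; // larger than the 8192-byte chunk

    // A single read() stops at the chunk size even though the buffer has room.
    ByteBuffer once = ByteBuffer.allocate(payload.length);
    int single = new ChunkedChannel(payload, 8192).read(once);

    // Looping until the buffer is full retrieves the whole payload.
    ByteBuffer all = ByteBuffer.allocate(payload.length);
    int looped = readFully(new ChunkedChannel(payload, 8192), all);

    System.out.println("single read: " + single + ", looped read: " + looped);
  }
}
```

A reader that compares the result of one `read()` call against the expected record length would reject the 10000-byte record here, which is consistent with the "Invalid data" check failing only for records larger than 8192 bytes.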

It is a bug; see https://issues.apache.org/jira/browse/BEAM-7695. I have fixed it.