将输入添加到 tensorflow.js 中的 `tf.data.generator`

Adding inputs to a `tf.data.generator` in tensorflow.js

我正在尝试创建一个数据生成器,我已验证它可以在纯 js 中独立运行。它的 TFJS 文档在这里,有两个例子: https://js.tensorflow.org/api/latest/#data.generator

我想使用 tf.data.generator,因为此数据集需要进行精细的预处理。一个最小的例子如下:

const tf = require('@tensorflow/tfjs-node');
class dataGeneratorGenerator {
    constructor(test) {
        this.test = test
    }
    * dataGenerator() {
        let len = this.test.length
        let idx = 0
        while (idx < len) {
            idx++
            console.log(idx)
            yield this.test[idx]
        }
    }
}
let dgg = new dataGeneratorGenerator(['hi', 'hi2', 'hi3'])
let trainDs = tf.data.generator(dgg.dataGenerator);
trainDs.forEachAsync(e => console.log(e));

错误如下:

TypeError: Error thrown while iterating through a dataset: Cannot read property 'test' of undefined

在纯 javascript 中迭代我们的数据生成器有效:

let dgg = new dataGeneratorGenerator(['hi', 'hi2', 'hi3'])
let dg = dgg.dataGenerator()
console.log(dgg.next())
console.log(dgg.next())
console.log(dgg.next())

我的理解是我们只是将 dataGenerator 传递给 tf.data.generator 而不是整个 class。那么,如何将变量输入tf.data.generator呢?谢谢。

可以简单地使用箭头函数。

const tf = require('@tensorflow/tfjs-node');

function* dataGenerator(test) {
    let len = test.length
    let idx = 0
    while (idx < len) {
        idx++
        console.log(idx)
    }
}

let trainDs = tf.data.generator(() => dataGenerator(['hi', 'hi2', 'hi3']));

trainDs.forEachAsync(e => console.log(e));