将输入添加到 tensorflow.js 中的 `tf.data.generator`
Adding inputs to a `tf.data.generator` in tensorflow.js
我正在尝试创建一个数据生成器,我已验证它可以在纯 js 中独立运行。它的 TFJS 文档在这里,有两个例子:
https://js.tensorflow.org/api/latest/#data.generator
我想使用 tf.data.generator
,因为此数据集需要进行精细的预处理。一个最小的例子如下:
const tf = require('@tensorflow/tfjs-node');
class dataGeneratorGenerator {
constructor(test) {
this.test = test
}
* dataGenerator() {
let len = this.test.length
let idx = 0
while (idx < len) {
idx++
console.log(idx)
yield this.test[idx]
}
}
}
let dgg = new dataGeneratorGenerator(['hi', 'hi2', 'hi3'])
let trainDs = tf.data.generator(dgg.dataGenerator);
trainDs.forEachAsync(e => console.log(e));
错误如下:
TypeError: Error thrown while iterating through a dataset: Cannot read property 'test' of undefined
在纯 javascript 中迭代我们的数据生成器有效:
let dgg = new dataGeneratorGenerator(['hi', 'hi2', 'hi3'])
let dg = dgg.dataGenerator()
console.log(dgg.next())
console.log(dgg.next())
console.log(dgg.next())
我的理解是我们只是将 dataGenerator
传递给 tf.data.generator
而不是整个 class。那么,如何将变量输入tf.data.generator
呢?谢谢。
可以简单地使用箭头函数。
const tf = require('@tensorflow/tfjs-node');
function* dataGenerator(test) {
let len = test.length
let idx = 0
while (idx < len) {
idx++
console.log(idx)
}
}
let trainDs = tf.data.generator(() => dataGenerator(['hi', 'hi2', 'hi3']));
trainDs.forEachAsync(e => console.log(e));
我正在尝试创建一个数据生成器,我已验证它可以在纯 js 中独立运行。它的 TFJS 文档在这里,有两个例子: https://js.tensorflow.org/api/latest/#data.generator
我想使用 tf.data.generator
,因为此数据集需要进行精细的预处理。一个最小的例子如下:
const tf = require('@tensorflow/tfjs-node');
class dataGeneratorGenerator {
constructor(test) {
this.test = test
}
* dataGenerator() {
let len = this.test.length
let idx = 0
while (idx < len) {
idx++
console.log(idx)
yield this.test[idx]
}
}
}
let dgg = new dataGeneratorGenerator(['hi', 'hi2', 'hi3'])
let trainDs = tf.data.generator(dgg.dataGenerator);
trainDs.forEachAsync(e => console.log(e));
错误如下:
TypeError: Error thrown while iterating through a dataset: Cannot read property 'test' of undefined
在纯 javascript 中迭代我们的数据生成器有效:
let dgg = new dataGeneratorGenerator(['hi', 'hi2', 'hi3'])
let dg = dgg.dataGenerator()
console.log(dgg.next())
console.log(dgg.next())
console.log(dgg.next())
我的理解是我们只是将 dataGenerator
传递给 tf.data.generator
而不是整个 class。那么,如何将变量输入tf.data.generator
呢?谢谢。
可以简单地使用箭头函数。
const tf = require('@tensorflow/tfjs-node');
function* dataGenerator(test) {
let len = test.length
let idx = 0
while (idx < len) {
idx++
console.log(idx)
}
}
let trainDs = tf.data.generator(() => dataGenerator(['hi', 'hi2', 'hi3']));
trainDs.forEachAsync(e => console.log(e));