How do you send arguments to a generator function using tf.data.Dataset.from_generator()?
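I want to create multiple tf.data.Dataset objects using the from_generator() function, and I want to send an argument to the generator function (raw_data_gen). The idea is that the generator function yields different data depending on the argument it receives, so that raw_data_gen can provide training, validation, or test data.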
```python
training_dataset = tf.data.Dataset.from_generator(raw_data_gen, (tf.float32, tf.uint8), ([None, 1], [None]), args=([1]))
validation_dataset = tf.data.Dataset.from_generator(raw_data_gen, (tf.float32, tf.uint8), ([None, 1], [None]), args=([2]))
test_dataset = tf.data.Dataset.from_generator(raw_data_gen, (tf.float32, tf.uint8), ([None, 1], [None]), args=([3]))
```
The error message I get when I call from_generator() this way is:
TypeError: from_generator() got an unexpected keyword argument 'args'
Here is the raw_data_gen function, although I'm not sure you need it, since my hunch is that the problem lies in the call to from_generator():
```python
import soundfile  # soundfile.read() returns (data, samplerate)

# training_filepath_label_dict, validation_filepath_label_dict and
# test_filepath_label_dict map file paths to labels and are defined elsewhere.
def raw_data_gen(train_val_or_test):
    if train_val_or_test == 1:
        # For every (filename, label) pair collected in the dict
        for filename, lab in training_filepath_label_dict.items():
            raw_data, samplerate = soundfile.read(filename)
            try:  # assume the audio is stereo, ready to be sliced
                raw_data = raw_data[:, 0]  # raw_data is a np.array; take the first channel with a slice
            except IndexError:
                pass  # this must be mono audio
            yield raw_data, lab
    elif train_val_or_test == 2:
        # For every (filename, label) pair collected in the dict
        for filename, lab in validation_filepath_label_dict.items():
            raw_data, samplerate = soundfile.read(filename)
            try:  # assume the audio is stereo, ready to be sliced
                raw_data = raw_data[:, 0]  # raw_data is a np.array; take the first channel with a slice
            except IndexError:
                pass  # this must be mono audio
            yield raw_data, lab
    elif train_val_or_test == 3:
        # For every (filename, label) pair collected in the dict
        for filename, lab in test_filepath_label_dict.items():
            raw_data, samplerate = soundfile.read(filename)
            try:  # assume the audio is stereo, ready to be sliced
                raw_data = raw_data[:, 0]  # raw_data is a np.array; take the first channel with a slice
            except IndexError:
                pass  # this must be mono audio
            yield raw_data, lab
    else:
        print("generator function called with an argument not in [1, 2, 3]")
        raise ValueError()
```
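The three branches above differ only in which dictionary they read. One possible refactor is sketched below under stated assumptions: `make_raw_data_gen`, `split_to_dict` and `read_audio` are hypothetical names introduced here (in real use, `read_audio` would be `soundfile.read`), and the try/except is swapped for an explicit `ndim` check. Converting the argument with `int()` also lets the same generator accept a NumPy integer, which is what `from_generator(args=...)` passes in.

```python
import numpy as np


def make_raw_data_gen(split_to_dict, read_audio):
    """Build a generator function equivalent to raw_data_gen (hypothetical refactor).

    split_to_dict: mapping such as {1: training dict, 2: validation dict, 3: test dict}
    read_audio: callable returning (data, samplerate), e.g. soundfile.read
    """
    def raw_data_gen(train_val_or_test):
        # int() accepts Python ints and NumPy integer scalars alike.
        key = int(train_val_or_test)
        if key not in split_to_dict:
            raise ValueError(
                f"generator function called with {key}, expected one of {sorted(split_to_dict)}"
            )
        for filename, lab in split_to_dict[key].items():
            raw_data, samplerate = read_audio(filename)
            if np.ndim(raw_data) > 1:  # stereo: keep only the first channel
                raw_data = raw_data[:, 0]
            yield raw_data, lab

    return raw_data_gen
```

With the question's data this would be `make_raw_data_gen({1: training_filepath_label_dict, 2: validation_filepath_label_dict, 3: test_filepath_label_dict}, soundfile.read)`. Note that, as in the original, the `ValueError` is only raised once the generator is first advanced, not when it is called.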
You need to define a new function, based on raw_data_gen, that takes no arguments. You can do this with the lambda keyword:
```python
training_dataset = tf.data.Dataset.from_generator(lambda: raw_data_gen(train_val_or_test=1), (tf.float32, tf.uint8), ([None, 1], [None]))
...
```
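Now we pass from_generator a function that takes no arguments but simply behaves like raw_data_gen with the argument set to 1. You can use the same pattern for the validation and test sets, passing 2 and 3 respectively.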
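If the three datasets are built in a loop, plain lambdas run into Python's late binding: `from_generator` calls the zero-argument callable lazily, each time the dataset is iterated, and by then every lambda sees the loop variable's final value. The pure-Python sketch below shows the pitfall and two fixes (a default argument and `functools.partial`); `toy_gen` and `splits` are hypothetical stand-ins for `raw_data_gen` and the split codes.

```python
from functools import partial


def toy_gen(train_val_or_test):
    # Hypothetical stand-in for raw_data_gen: yields one marker per split.
    yield f"split {train_val_or_test}"


splits = {"training": 1, "validation": 2, "test": 3}

# Pitfall: each lambda looks up `code` when it is called, after the loop has ended,
# so every entry ends up calling toy_gen(3).
late_bound = {name: (lambda: toy_gen(code)) for name, code in splits.items()}

# Fix 1: bind the current value as a default argument.
bound = {name: (lambda code=code: toy_gen(code)) for name, code in splits.items()}

# Fix 2: functools.partial also produces a zero-argument callable.
partial_bound = {name: partial(toy_gen, code) for name, code in splits.items()}
```

Any value in `bound` or `partial_bound` can be passed as the first argument of `tf.data.Dataset.from_generator()` in place of the lambda shown above.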
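The TypeError above suggests an older TensorFlow release whose from_generator() does not accept args. In newer releases such as TensorFlow 2.4, it does. Note that args must be a tuple: (1) is just the integer 1, so a one-element tuple needs a trailing comma, (1,):
```python
training_dataset = tf.data.Dataset.from_generator(
    raw_data_gen,
    args=(1,),  # one-element tuple; args=(1) would pass a bare int
    output_types=(tf.float32, tf.uint8),
    output_shapes=([None, 1], [None]))
```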
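A minimal, self-contained sketch of the args route, assuming TensorFlow 2.4 or newer, where `output_signature` was added as the replacement for `output_types`/`output_shapes`. `TOY_SPLITS`, `toy_gen` and `make_dataset` are hypothetical names with toy data standing in for the question's dictionaries. Each argument reaches the generator as a NumPy value rather than a Python int, hence the `int()` call. The signature uses shape `(None,)` for the audio and `()` for the label on the assumption that, like the question's generator, it yields one 1-D channel and one label per example.

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy data standing in for the question's filepath->label dicts:
# each split maps to a list of (mono audio samples, label) pairs.
TOY_SPLITS = {
    1: [(np.array([0.1, 0.2, 0.3], dtype=np.float32), 0)],
    2: [(np.array([0.4, 0.5], dtype=np.float32), 1)],
    3: [(np.array([0.6], dtype=np.float32), 2)],
}


def toy_gen(train_val_or_test):
    # The argument arrives as a NumPy value, not a Python int, so convert it.
    for audio, label in TOY_SPLITS[int(train_val_or_test)]:
        yield audio, label


def make_dataset(train_val_or_test):
    return tf.data.Dataset.from_generator(
        toy_gen,
        args=(train_val_or_test,),  # one-element tuple; (1) would be just the int 1
        output_signature=(
            tf.TensorSpec(shape=(None,), dtype=tf.float32),  # 1-D mono signal
            tf.TensorSpec(shape=(), dtype=tf.uint8),  # one label per example
        ),
    )


training_dataset = make_dataset(1)
validation_dataset = make_dataset(2)
test_dataset = make_dataset(3)
```

Wrapping the call in a small factory keeps the three datasets from drifting apart; string arguments passed through args would arrive as bytes and need decoding in the generator.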