如何从 tf.tensor 中获取字符串值,其中 dtype 是字符串
how to get string value out of tf.tensor which dtype is string
我想使用 tf.data.Dataset.list_files 函数来提供我的数据集。
但是因为文件不是图片,需要手动加载
问题是 tf.data.Dataset.list_files 将变量作为 tf.tensor 传递,而我的 python 代码无法处理张量。
如何从 tf.tensor 获取字符串值。
dtype 是字符串。
train_dataset = tf.data.Dataset.list_files(PATH+'clean_4s_val/*.wav')
train_dataset = train_dataset.map(lambda x: load_audio_file(x))
def load_audio_file(file_path):
print("file_path: ", file_path)
# i want do something like string_path = convert_tensor_to_string(file_path)
file_path 是 Tensor("arg0:0", shape=(), dtype=string)
我使用 tensorflow 1.13.1 和 eager 模式。
提前致谢
你可以用tf.py_func
来换行load_audio_file()
。
import tensorflow as tf
tf.enable_eager_execution()
def load_audio_file(file_path):
# you should decode bytes type to string type
print("file_path: ",bytes.decode(file_path),type(bytes.decode(file_path)))
return file_path
train_dataset = tf.data.Dataset.list_files('clean_4s_val/*.wav')
train_dataset = train_dataset.map(lambda x: tf.py_func(load_audio_file, [x], [tf.string]))
for one_element in train_dataset:
print(one_element)
file_path: clean_4s_val/1.wav <class 'str'>
(<tf.Tensor: id=32, shape=(), dtype=string, numpy=b'clean_4s_val/1.wav'>,)
file_path: clean_4s_val/3.wav <class 'str'>
(<tf.Tensor: id=34, shape=(), dtype=string, numpy=b'clean_4s_val/3.wav'>,)
file_path: clean_4s_val/2.wav <class 'str'>
(<tf.Tensor: id=36, shape=(), dtype=string, numpy=b'clean_4s_val/2.wav'>,)
更新 TF 2
上述解决方案不适用于 TF 2(已通过 2.2.0 测试),即使将 tf.py_func
替换为 tf.py_function
,给出
InvalidArgumentError: TypeError: descriptor 'decode' requires a 'bytes' object but received a 'tensorflow.python.framework.ops.EagerTensor'
要使其在 TF 2 中运行,请进行以下更改:
- 删除
tf.enable_eager_execution()
(在TF 2中eager是enabled by default,你可以通过tf.executing_eagerly()
返回True
来验证)
- 将
tf.py_func
替换为tf.py_function
- 将
file_path
的所有 in-function 引用替换为 file_path.numpy()
如果您想做一些完全自定义的事情,那么您应该将代码包装在 tf.py_function
中。请记住,这将导致性能不佳。请在此处查看文档和示例:
https://www.tensorflow.org/api_docs/python/tf/data/Dataset#map
另一方面,如果您正在做一些通用的事情,那么您不需要将代码包装在 py_function
中,而是使用 tf.strings
模块中提供的任何方法。这些方法适用于字符串张量,并提供许多常用方法,如 split、join、len 等。这些方法不会对性能产生负面影响,它们将直接作用于张量和 return 修改后的张量。
在此处查看 tf.strings
的文档:https://www.tensorflow.org/api_docs/python/tf/strings
例如,假设您想从文件名中提取标签的名称,然后您可以编写如下代码:
ds.map(lambda x: tf.strings.split(x, sep='$')[1])
以上假设标签由$
分隔。
如果您真的只想将 Tensor 解包为其字符串内容 - 您需要序列化 TFRecord 才能使用 tf_example.SerializeToString() - 获取(可打印的)字符串值 - 请参阅 here
我想使用 tf.data.Dataset.list_files 函数来提供我的数据集。
但是因为文件不是图片,需要手动加载
问题是 tf.data.Dataset.list_files 将变量作为 tf.tensor 传递,而我的 python 代码无法处理张量。
如何从 tf.tensor 获取字符串值。 dtype 是字符串。
train_dataset = tf.data.Dataset.list_files(PATH+'clean_4s_val/*.wav')
train_dataset = train_dataset.map(lambda x: load_audio_file(x))
def load_audio_file(file_path):
print("file_path: ", file_path)
# i want do something like string_path = convert_tensor_to_string(file_path)
file_path 是 Tensor("arg0:0", shape=(), dtype=string)
我使用 tensorflow 1.13.1 和 eager 模式。
提前致谢
你可以用tf.py_func
来换行load_audio_file()
。
import tensorflow as tf
tf.enable_eager_execution()
def load_audio_file(file_path):
# you should decode bytes type to string type
print("file_path: ",bytes.decode(file_path),type(bytes.decode(file_path)))
return file_path
train_dataset = tf.data.Dataset.list_files('clean_4s_val/*.wav')
train_dataset = train_dataset.map(lambda x: tf.py_func(load_audio_file, [x], [tf.string]))
for one_element in train_dataset:
print(one_element)
file_path: clean_4s_val/1.wav <class 'str'>
(<tf.Tensor: id=32, shape=(), dtype=string, numpy=b'clean_4s_val/1.wav'>,)
file_path: clean_4s_val/3.wav <class 'str'>
(<tf.Tensor: id=34, shape=(), dtype=string, numpy=b'clean_4s_val/3.wav'>,)
file_path: clean_4s_val/2.wav <class 'str'>
(<tf.Tensor: id=36, shape=(), dtype=string, numpy=b'clean_4s_val/2.wav'>,)
更新 TF 2
上述解决方案不适用于 TF 2(已通过 2.2.0 测试),即使将 tf.py_func
替换为 tf.py_function
,给出
InvalidArgumentError: TypeError: descriptor 'decode' requires a 'bytes' object but received a 'tensorflow.python.framework.ops.EagerTensor'
要使其在 TF 2 中运行,请进行以下更改:
- 删除
tf.enable_eager_execution()
(在TF 2中eager是enabled by default,你可以通过tf.executing_eagerly()
返回True
来验证) - 将
tf.py_func
替换为tf.py_function
- 将
file_path
的所有 in-function 引用替换为file_path.numpy()
如果您想做一些完全自定义的事情,那么您应该将代码包装在 tf.py_function
中。请记住,这将导致性能不佳。请在此处查看文档和示例:
https://www.tensorflow.org/api_docs/python/tf/data/Dataset#map
另一方面,如果您正在做一些通用的事情,那么您不需要将代码包装在 py_function
中,而是使用 tf.strings
模块中提供的任何方法。这些方法适用于字符串张量,并提供许多常用方法,如 split、join、len 等。这些方法不会对性能产生负面影响,它们将直接作用于张量和 return 修改后的张量。
在此处查看 tf.strings
的文档:https://www.tensorflow.org/api_docs/python/tf/strings
例如,假设您想从文件名中提取标签的名称,然后您可以编写如下代码:
ds.map(lambda x: tf.strings.split(x, sep='$')[1])
以上假设标签由$
分隔。
如果您真的只想将 Tensor 解包为其字符串内容 - 您需要序列化 TFRecord 才能使用 tf_example.SerializeToString() - 获取(可打印的)字符串值 - 请参阅 here