Advancing tensorflow dataset iterator in python multiprocessing Queue
Is there any way to advance the iterator in this example?
```python
import tensorflow as tf
import numpy as np
from multiprocessing import Process, Queue

def store(batch, queue):
    while True:
        queue.put(batch)

if __name__=='__main__':
    pqueue = Queue()
    a1 = np.arange(1000)
    m = tf.data.Dataset.from_tensor_slices(a1).repeat().batch(1)
    iter_m = m.make_one_shot_iterator()
    m_init_ops = iter_m.make_initializer(m)
    next_m = iter_m.get_next()
    with tf.Session() as sess:
        batch = sess.run(next_m)
        pp_process = Process(target=store,args=(batch, pqueue,))
        pp_process.daemon = True
        pp_process.start()
        for i in range(10):
            print(pqueue.get())
```
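My idea is to store the processed data in a queue so that TensorFlow can access it for training, but unfortunately I cannot advance the iterator. Any suggestions would be appreciated.
The current output is: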
[0]
[0]
[0]
[0]
[0]
[0]
[0]
[0]
[0]
[0]
TensorFlow multithreading
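The iterator does not advance because you technically run the get_next op only once: `sess.run(next_m)`. That call returns a concrete NumPy array, and `store` then puts that same array into the queue over and over. If you only use TensorFlow with multithreading, simply move the call into the `store` function to get the desired result: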
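```python
from threading import Thread

def store(sess, next_m, queue):
    # Every sess.run(next_m) call pulls the next batch from the iterator
    while True:
        queue.put(sess.run(next_m))

# batch = sess.run(next_m) <- Remove
pp_process = Thread(target=store, args=(sess, next_m, pqueue,))  # <- Thread with correct args passed
```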
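To see why moving the call into `store` matters, here is a minimal stdlib-only analogy, assuming no TensorFlow at all. `make_batches` is a hypothetical stand-in for the repeating dataset, `store_once` mirrors the question (a batch evaluated a single time, then queued repeatedly), and `store_advancing` mirrors the fix (a fresh batch fetched inside the loop). The loops are bounded by a hypothetical `n` instead of `while True` so the threads finish.

```python
import itertools
import queue
import threading

def make_batches():
    # Hypothetical stand-in for the dataset: 0..999 repeated forever, batch size 1
    return ([x] for x in itertools.cycle(range(1000)))

def store_once(batch, q, n):
    # Like the question: the batch was evaluated once, so the same value is queued n times
    for _ in range(n):
        q.put(batch)

def store_advancing(batches, q, n):
    # Like the answer: fetch a new batch on every iteration
    for _ in range(n):
        q.put(next(batches))

def collect(target, args, n):
    # Run the producer in a background thread and read n items from the queue
    q = queue.Queue()
    t = threading.Thread(target=target, args=(*args, q, n), daemon=True)
    t.start()
    out = [q.get() for _ in range(n)]
    t.join()
    return out

batches = make_batches()
print(collect(store_once, (next(batches),), 3))        # [[0], [0], [0]]
print(collect(store_advancing, (make_batches(),), 3))  # [[0], [1], [2]]
```

The first call reproduces the stuck `[0]` output from the question; the second advances because the producer pulls from the iterator itself.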
TensorFlow multiprocessing
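With multiprocessing, however, you must also make sure that you never instantiate (fork) a new process after a session has already been created, because session objects cannot be serialized.
In your case, you can simply create a new session inside the `store` function and start the main session only after forking: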
```python
from multiprocessing import Process, Queue
import numpy as np
import tensorflow as tf

def store(next_m, queue):
    with tf.Session() as sess:
        while True:
            queue.put(sess.run(next_m))

if __name__ == '__main__':
    ...
    pp_process = Process(target=store, args=(next_m, pqueue,))
    pp_process.daemon = True
    pp_process.start()  # <- Fork before starting this session!
    with tf.Session() as sess:
        for i in range(10):
            print(pqueue.get())
```
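The same "create per-process state inside the child, after the fork" pattern can be sketched with the standard library alone. In this hedged example, a plain Python iterator stands in for the TensorFlow iterator and session (assumption: no TensorFlow involved), `producer` and `consume` are hypothetical names, the `"fork"` start method is assumed to be available (Linux or macOS, not Windows), and a `None` sentinel lets the consumer stop instead of looping forever.

```python
import multiprocessing as mp

def producer(n_items, queue):
    # State is created inside the child, after the fork,
    # analogous to opening tf.Session() inside store()
    source = iter(range(n_items))  # hypothetical stand-in for the dataset iterator
    for item in source:
        queue.put([item])  # one "batch" of size 1
    queue.put(None)  # sentinel: tells the consumer the producer is done

def consume(n_items):
    ctx = mp.get_context("fork")  # assumption: a POSIX platform where "fork" exists
    queue = ctx.Queue(maxsize=8)  # bounded so the producer cannot run far ahead
    proc = ctx.Process(target=producer, args=(n_items, queue), daemon=True)
    proc.start()  # start the child before creating heavyweight state in the parent
    batches = []
    while True:
        batch = queue.get()
        if batch is None:
            break
        batches.append(batch)
    proc.join()  # safe: every item, including the sentinel, has been consumed
    return batches

if __name__ == "__main__":
    print(consume(10))  # [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]]
```

The bounded queue keeps the producer from racing arbitrarily far ahead of the consumer, which is usually what you want when prefetching training batches.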