在工作人员上加载本地(不可序列化)对象
Load local (unserializable) objects on workers
我正在尝试结合使用 Dataflow 和 Tensorflow 进行预测。这些预测发生在工人身上,我目前正在通过 startup_bundle()
加载模型。喜欢这里:
class PredictDoFn(beam.DoFn):
def start_bundle(self):
self.model = load_model_from_file()
def process(self, element):
...
我目前的问题是,即使我处理了 1000 个元素,startup_bundle()
函数也会被调用多次(至少 10 次),而不是像我希望的那样每次工作一次。这会显着减慢管道速度,因为模型需要加载很多次并且每次需要 30 秒。
有什么方法可以在初始化时将模型加载到 worker 上,而不是每次都在 start_bundle()
中加载模型?
提前致谢!
迪米特里
最简单的方法是添加一个 if self.model is None: self.model = load_model_from_file()
,这可能不会减少重新加载模型的次数。
这是因为 DoFn 实例目前未跨包重复使用。这意味着您的模型将在每个工作项执行后被遗忘。
您还可以在保存模型的地方创建一个 global
变量。这将减少重新加载的数量,但这确实是非正统的(尽管它可能会解决您的用例)。
全局变量方法应该像这样工作:
class MyModelDoFn(object):
def process(self, elem):
global my_model
if my_model is None:
my_model = load_model_from_file()
yield my_model.apply_to(elem)
依赖线程局部变量的方法看起来像这样。考虑到这将为每个线程加载一次模型,因此加载模型的次数取决于运行器实现(它将在 Dataflow 中工作):
class MyModelDoFn(object):
_thread_local = threading.local()
@property
def model(self):
model = getattr(MyModelDoFn._thread_local, 'model', None)
if not model:
MyModelDoFn._thread_local.model = load_model_from_file()
return MyModelDoFn._thread_local.model
def process(self, elem):
yield self.model.apply_to(elem)
我想您也可以从 start_bundle
调用中加载模型。
注意:这种方法非常不正统,不能保证在新版本中工作,也不能保证所有运行程序。
我正在尝试结合使用 Dataflow 和 Tensorflow 进行预测。这些预测发生在工人身上,我目前正在通过 startup_bundle()
加载模型。喜欢这里:
class PredictDoFn(beam.DoFn):
def start_bundle(self):
self.model = load_model_from_file()
def process(self, element):
...
我目前的问题是,即使我处理了 1000 个元素,startup_bundle()
函数也会被调用多次(至少 10 次),而不是像我希望的那样每次工作一次。这会显着减慢管道速度,因为模型需要加载很多次并且每次需要 30 秒。
有什么方法可以在初始化时将模型加载到 worker 上,而不是每次都在 start_bundle()
中加载模型?
提前致谢! 迪米特里
最简单的方法是添加一个 if self.model is None: self.model = load_model_from_file()
,这可能不会减少重新加载模型的次数。
这是因为 DoFn 实例目前未跨包重复使用。这意味着您的模型将在每个工作项执行后被遗忘。
您还可以在保存模型的地方创建一个 global
变量。这将减少重新加载的数量,但这确实是非正统的(尽管它可能会解决您的用例)。
全局变量方法应该像这样工作:
class MyModelDoFn(object):
def process(self, elem):
global my_model
if my_model is None:
my_model = load_model_from_file()
yield my_model.apply_to(elem)
依赖线程局部变量的方法看起来像这样。考虑到这将为每个线程加载一次模型,因此加载模型的次数取决于运行器实现(它将在 Dataflow 中工作):
class MyModelDoFn(object):
_thread_local = threading.local()
@property
def model(self):
model = getattr(MyModelDoFn._thread_local, 'model', None)
if not model:
MyModelDoFn._thread_local.model = load_model_from_file()
return MyModelDoFn._thread_local.model
def process(self, elem):
yield self.model.apply_to(elem)
我想您也可以从 start_bundle
调用中加载模型。
注意:这种方法非常不正统,不能保证在新版本中工作,也不能保证所有运行程序。