InfeedEnqueueTuple issue when trying to restore updated BERT model checkpoint using Cloud TPU
I'd appreciate any help with the following; thanks in advance. I copied Google BERT's notebook on fine-tuning and trained on the SQuAD dataset using a Cloud TPU and a GCS bucket. Predictions on the dev set were fine, so I downloaded the checkpoint files (model.ckpt.meta, model.ckpt.index, and model.ckpt.data) locally and tried to restore with this code:
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
saver = tf.train.import_meta_graph(META_FILE) # META_FILE being path to .meta
saver.restore(sess, 'model.ckpt')
However, I got this error:
op_def = op_dict[node.op]
KeyError: 'InfeedEnqueueTuple'
I assumed the op is part of the Cloud TPU tools and that I should continue on the Cloud TPU, so I tried the following:
# code from cells before includes
...
tf.contrib.cloud.configure_gcs(session, credentials=auth_info)
...
tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(TPU_ADDRESS)
run_config = tf.contrib.tpu.RunConfig(
cluster=tpu_cluster_resolver,
model_dir=OUTPUT_DIR,
save_checkpoints_steps=SAVE_CHECKPOINTS_STEPS,
tpu_config=tf.contrib.tpu.TPUConfig(
iterations_per_loop=ITERATIONS_PER_LOOP,
num_shards=NUM_TPU_CORES,
per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2))
...
The problem cell:
"""
# not valid checkpoint error. <bucket> placeholder for cloud bucket name
sess = tf.Session()
META_FILE = "gs://<bucket>/bert/models/bertsquad/model.ckpt-10949.meta"
CKPT_FILE = "gs://<bucket>/bert/models/bertsquad/model.ckpt"
saver = tf.train.import_meta_graph(META_FILE)
saver.restore(sess, CKPT_FILE)
"""
from google.cloud import storage
from tensorflow import MetaGraphDef

client = storage.Client(project="agent-helper-4a014")
bucket = client.get_bucket("<bucket>")
metafile = "bert/models/bertsquad/model.ckpt-10949.meta"
# using full path gs://<bucket>/bert/models/bertsquad doesn't work
blob = bucket.get_blob(metafile)
#blob = bucket.blob(metafile)
#model_graph = blob.download_to_filename("model.ckpt")
model_graph = blob.download_as_string()
mgd = MetaGraphDef()
mgd.ParseFromString(model_graph)
with tf.Session() as sess:
    saver = tf.train.import_meta_graph(mgd, clear_devices=True)
    init_checkpoint = saver.restore(sess, 'model.ckpt')
This in turn produced the following error:
InvalidArgumentError (see above for traceback): Restoring from checkpoint failed. This is most likely due to a mismatch between the current graph and the graph from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error:
No OpKernel was registered to support Op 'InfeedEnqueueTuple' with these attrs. Registered devices: [CPU,XLA_CPU], Registered kernels:
<no registered kernels>
[[node input_pipeline_task0/while/InfeedQueue/enqueue/0 (defined at <ipython-input-67-e4b52b7b5944>:21) = InfeedEnqueueTuple[_class=["loc:@input_pipeline_task0/while/IteratorGetNext"], device_ordinal=0, dtypes=[DT_INT32, DT_INT32, DT_INT32, DT_INT32, DT_INT32, DT_INT32], shapes=[[2], [2,384], [2,384], [2,384], [2], [2]], _device="/job:worker/task:0/device:CPU:0"](input_pipeline_task0/while/IteratorGetNext, input_pipeline_task0/while/IteratorGetNext:1, input_pipeline_task0/while/IteratorGetNext:2, input_pipeline_task0/while/IteratorGetNext:3, input_pipeline_task0/while/IteratorGetNext:4, input_pipeline_task0/while/IteratorGetNext:5)]]
If your goal is prediction, then just point model_dir at the location (it must be a GCS bucket) where the checkpoints and meta files were saved. The code will not train again, because the checkpoint was already saved at the configured number of training steps and the model graph is unchanged; it will jump straight to prediction.
However, if your use case really is to save checkpoints and then restore them purely for inference, follow these steps:
- Recreate the model network manually, layer by layer, exactly as in the original model, or recreate it from the saved .meta file using the tf.train.import_meta_graph() function, as below:
saver = tf.train.import_meta_graph('<filename>.meta')
- Now restore the checkpoint with:
saver.restore(sess, 'model.ckpt')
Note: the model graph into which you restore a checkpoint must be exactly the same as the original graph from which the checkpoint was saved.
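The rebuild-and-restore pattern above can be sketched end to end with a toy graph. This is a minimal sketch, not BERT itself: the variables w and b and the checkpoint path are hypothetical, and it assumes the TF 1.x-style session API (written here via tf.compat.v1 so it also runs under TF 2.x). The point is that restoring succeeds once the second graph defines the same variable names and shapes that the checkpoint contains, with no TPU infeed ops in sight:

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1  # TF 1.x-style API (the original notebook uses TF 1.x directly)
tf1.disable_eager_execution()

ckpt_prefix = os.path.join(tempfile.mkdtemp(), "model.ckpt")

# 1. Build a toy "training" graph and save a checkpoint.
train_graph = tf1.Graph()
with train_graph.as_default():
    w = tf1.get_variable("w", initializer=tf.constant(3.0))
    b = tf1.get_variable("b", initializer=tf.constant(1.5))
    saver = tf1.train.Saver()
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        saver.save(sess, ckpt_prefix)

# 2. Rebuild the *same* graph (same variable names, shapes, dtypes)
#    in a fresh Graph, then restore the checkpoint into it.
infer_graph = tf1.Graph()
with infer_graph.as_default():
    w2 = tf1.get_variable("w", shape=[], dtype=tf.float32)
    b2 = tf1.get_variable("b", shape=[], dtype=tf.float32)
    restorer = tf1.train.Saver()
    with tf1.Session() as sess:
        restorer.restore(sess, ckpt_prefix)
        restored = sess.run([w2, b2])

print(restored)  # the values saved in step 1
```

For the BERT checkpoint, step 2 corresponds to recreating the network from modeling.py (a CPU/GPU graph, not the TPU-serialized .meta) and restoring into that; the InfeedEnqueueTuple op never appears because it was never part of the rebuilt graph.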
Hope this solves your problem.