__init__() 缺少 1 个必需的位置参数:tensorflow 中的 'sess'
__init__() missing 1 required positional argument: 'sess' in tensorflow
我正在尝试在此脚本中使用 类,它对目录 'test_images' 中的多个图像执行图像分类。我之前没有经常使用 类 所以我对如何在这种情况下正确应用它们有点困惑。错误是:TypeError: __init__() missing 1 required positional argument: 'sess'
。任何帮助将不胜感激!
下面是代码:
def image_recognition_algorithm():
def load_graph(model_file):
graph = tf.Graph()
graph_def = tf.GraphDef()
with open(model_file, "rb") as f:
graph_def.ParseFromString(f.read())
with graph.as_default():
tf.import_graph_def(graph_def)
return graph
def read_tensor_from_image_file(file_name, input_height=299, input_width=299,
input_mean=0, input_std=255):
input_name = "file_reader"
output_name = "normalized"
file_reader = tf.read_file(file_name, input_name)
image_reader = tf.image.decode_jpeg(file_reader, channels = 3, name='jpeg_reader')
float_caster = tf.cast(image_reader, tf.float32)
dims_expander = tf.expand_dims(float_caster, 0);
resized = tf.image.resize_bilinear(dims_expander, [input_height, input_width])
normalized = tf.divide(tf.subtract(resized, [input_mean]), [input_std])
sess = tf.Session()
result = sess.run(normalized)
return result
def load_labels(label_file):
label = []
proto_as_ascii_lines = tf.gfile.GFile(label_file).readlines()
for l in proto_as_ascii_lines:
label.append(l.rstrip())
return label
class initiate_session():
def __init__(self, sess):
self.sess = sess
graph = load_graph(model_file)
input_name = "import/" + input_layer
output_name = "import/" + output_layer
input_operation = graph.get_operation_by_name(input_name);
output_operation = graph.get_operation_by_name(output_name);
config = tf.ConfigProto(device_count={"CPU": 4},
inter_op_parallelism_threads=1,
intra_op_parallelism_threads=4)
self.sess = tf.Session(graph=graph, config = config)
start = time.time()
results = self.sess.run(output_operation.outputs[0],
{input_operation.outputs[0]: t})
end=time.time()
results = np.squeeze(results)
top_k = results.argsort()[-5:][::-1]
labels = load_labels(label_file)
print('\nEvaluation time (1-image): {:.3f}s\n'.format(end-start))
for i in top_k:
print(file_name, labels[i], results[i])
return [file_name] + list(results)
image_list = [f for f in listdir('test_images') if isfile(join('test_images', f))]
res_list = []
for image in image_list:
if image.lower().endswith(('.png', '.jpg', '.jpeg', '.gif')):
res_list.append(main(join('test_images', image)))
def main(self, file_name):
model_file = "tf_files/retrained_graph.pb"
label_file = "tf_files/retrained_labels.txt"
input_height = 299
input_width = 299
input_mean = 128
input_std = 128
input_layer = "Mul"
output_layer = "final_result"
t = read_tensor_from_image_file(file_name,
input_height=input_height,
input_width=input_width,
input_mean=input_mean,
input_std=input_std)
if __name__ == '__main__':
initiate_session().main()
你的initiate_session.__init__()
方法有两个参数,self
和sess
,它作为对自身的引用自动传入,sess
,你需要传入。当你实例化时initiate_session
这里:
if __name__ == '__main__':
initiate_session().main()
您需要传入一个 sess
参数。
然而,在你的情况下,我认为你想要做的实际上是删除 __init__()
方法的 sess
参数,因为你稍后要分配给 self.sess
在构造函数中,这里:
self.sess = tf.Session(graph=graph, config = config)
删除 __init__()
的 sess
参数和行
self.sess = sess
应该可以解决你的问题。
我正在尝试在此脚本中使用 类,它对目录 'test_images' 中的多个图像执行图像分类。我之前没有经常使用 类 所以我对如何在这种情况下正确应用它们有点困惑。错误是:TypeError: __init__() missing 1 required positional argument: 'sess'
。任何帮助将不胜感激!
下面是代码:
def image_recognition_algorithm():
def load_graph(model_file):
graph = tf.Graph()
graph_def = tf.GraphDef()
with open(model_file, "rb") as f:
graph_def.ParseFromString(f.read())
with graph.as_default():
tf.import_graph_def(graph_def)
return graph
def read_tensor_from_image_file(file_name, input_height=299, input_width=299,
input_mean=0, input_std=255):
input_name = "file_reader"
output_name = "normalized"
file_reader = tf.read_file(file_name, input_name)
image_reader = tf.image.decode_jpeg(file_reader, channels = 3, name='jpeg_reader')
float_caster = tf.cast(image_reader, tf.float32)
dims_expander = tf.expand_dims(float_caster, 0);
resized = tf.image.resize_bilinear(dims_expander, [input_height, input_width])
normalized = tf.divide(tf.subtract(resized, [input_mean]), [input_std])
sess = tf.Session()
result = sess.run(normalized)
return result
def load_labels(label_file):
label = []
proto_as_ascii_lines = tf.gfile.GFile(label_file).readlines()
for l in proto_as_ascii_lines:
label.append(l.rstrip())
return label
class initiate_session():
def __init__(self, sess):
self.sess = sess
graph = load_graph(model_file)
input_name = "import/" + input_layer
output_name = "import/" + output_layer
input_operation = graph.get_operation_by_name(input_name);
output_operation = graph.get_operation_by_name(output_name);
config = tf.ConfigProto(device_count={"CPU": 4},
inter_op_parallelism_threads=1,
intra_op_parallelism_threads=4)
self.sess = tf.Session(graph=graph, config = config)
start = time.time()
results = self.sess.run(output_operation.outputs[0],
{input_operation.outputs[0]: t})
end=time.time()
results = np.squeeze(results)
top_k = results.argsort()[-5:][::-1]
labels = load_labels(label_file)
print('\nEvaluation time (1-image): {:.3f}s\n'.format(end-start))
for i in top_k:
print(file_name, labels[i], results[i])
return [file_name] + list(results)
image_list = [f for f in listdir('test_images') if isfile(join('test_images', f))]
res_list = []
for image in image_list:
if image.lower().endswith(('.png', '.jpg', '.jpeg', '.gif')):
res_list.append(main(join('test_images', image)))
def main(self, file_name):
model_file = "tf_files/retrained_graph.pb"
label_file = "tf_files/retrained_labels.txt"
input_height = 299
input_width = 299
input_mean = 128
input_std = 128
input_layer = "Mul"
output_layer = "final_result"
t = read_tensor_from_image_file(file_name,
input_height=input_height,
input_width=input_width,
input_mean=input_mean,
input_std=input_std)
if __name__ == '__main__':
initiate_session().main()
你的initiate_session.__init__()
方法有两个参数,self
和sess
,它作为对自身的引用自动传入,sess
,你需要传入。当你实例化时initiate_session
这里:
if __name__ == '__main__':
initiate_session().main()
您需要传入一个 sess
参数。
然而,在你的情况下,我认为你想要做的实际上是删除 __init__()
方法的 sess
参数,因为你稍后要分配给 self.sess
在构造函数中,这里:
self.sess = tf.Session(graph=graph, config = config)
删除 __init__()
的 sess
参数和行
self.sess = sess
应该可以解决你的问题。