在 Caffe 中绘制网络导致 pydot 抛出 End of Line 错误

Drawing network in Caffe causes pydot to throw End of Line errors

所以我只是从master分支中拉取了Caffe的最新版本,并完成了所有的初始化步骤。作为快速测试,我尝试 运行 提供的 python/draw_net.py 脚本,以便可视化 MNIST 自动编码器示例网络。 在执行以下命令时:

./python/draw_net.py examples/mnist/mnist_autoencoder.prototxt trial_viz.png

Pydot 抱怨,并抛出以下错误:

Drawing net to trial_viz.png
Traceback (most recent call last):
  File "./python/draw_net.py", line 44, in <module>
    main()
  File "./python/draw_net.py", line 40, in main
    caffe.draw.draw_net_to_file(net, args.output_image_file, args.rankdir)
  File "/home/username/3rdparty/caffe/python/caffe/draw.py", line 165, in draw_net_to_file
    fid.write(draw_net(caffe_net, rankdir, ext))
  File "/home/username/3rdparty/caffe/python/caffe/draw.py", line 156, in draw_net
    return get_pydot_graph(caffe_net, rankdir).create(format=ext)
  File "/usr/lib/pymodules/python2.7/pydot.py", line 1796, in create
    status, stderr_output) )
pydot.InvocationException: Program terminated with status: 1. stderr follows: Warning: /tmp/tmpjqPQBC:5: string ran past end of line
Error: /tmp/tmpjqPQBC:6: syntax error near line 6
context:  >>> ( <<< Sigmoid)" [shape=record, style=filled, fillcolor="#6495ED"];
Warning: /tmp/tmpjqPQBC:6: ambiguous "6495ED" splits into two names: "6495" and "ED"
Warning: /tmp/tmpjqPQBC:6: string ran past end of line
Warning: /tmp/tmpjqPQBC:9: string ran past end of line
Warning: /tmp/tmpjqPQBC:10: string ran past end of line
Warning: /tmp/tmpjqPQBC:12: string ran past end of line
Warning: /tmp/tmpjqPQBC:13: ambiguous "6495ED" splits into two names: "6495" and "ED"
Warning: /tmp/tmpjqPQBC:13: string ran past end of line
Warning: /tmp/tmpjqPQBC:14: string ran past end of line
Warning: /tmp/tmpjqPQBC:15: string ran past end of line
Warning: /tmp/tmpjqPQBC:17: string ran past end of line
Warning: /tmp/tmpjqPQBC:18: ambiguous "6495ED" splits into two names: "6495" and "ED"

我看到了更多 Warning 消息,如上所示,我的错误日志变得太大,所以我没有 post 整个日志。 This post,似乎看到了和我一样的错误,所以我尝试复制他们的解决方案,并将 draw.pyget_pydot_graph() 方法中的所有字符串更改为原始字符串。但这似乎没有用。

关于如何解决这个问题有什么建议吗?

谢谢!! :)

我认为关键在于determine_node_label_by_layertype函数。这是一段代码,看起来应该像这样(或者至少在我当前版本的存储库中是这样):

def determine_node_label_by_layertype(layer, layertype, rankdir):
"""Define node label based on layer type
"""

    if rankdir in ('TB', 'BT'):
        # If graph orientation is vertical, horizontal space is free and
        # vertical space is not; separate words with spaces
        separator = ' '
    else:
        # If graph orientation is horizontal, vertical space is free and
        # horizontal space is not; separate words with newlines
        separator = '\n'

separater = '\n' 替换为 separater = r"\n",它似乎对我有用。