Arbitrary Image Stylization 模块 Colab 示例错误

Arbitrary Image Stylization module Colab example error

在尝试运行 Colab 上托管的 TensorFlow 团队的任意图像风格化（Arbitrary Image Stylization）示例代码时，我一直收到以下错误。

这是代码（如 this notebook 中所示，第 5 个代码块报错）：

from __future__ import absolute_import, division, print_function

import functools
import os

from matplotlib import gridspec
import matplotlib.pylab as plt
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub

# Report the installed library versions and runtime capabilities; the
# behaviour of the helpers below differs between TF 1.x and TF 2.x.
print("TF Version: ", tf.__version__)
print("TF-Hub version: ", hub.__version__)
print("Eager mode enabled: ", tf.executing_eagerly())
# NOTE(review): tf.test.is_gpu_available() is deprecated in later TF 2.x
# releases in favour of tf.config.list_physical_devices('GPU'); kept here
# because it also works on TF 1.x -- confirm before switching.
print("GPU available: ", tf.test.is_gpu_available())

# @title Define image loading and visualization functions  { display-mode: "form" }

def crop_center(image):
  """Crop an image to its largest centered square.

  Args:
    image: a batched image array/tensor whose shape[1] and shape[2] are read
      as height and width (assumes NHWC layout -- TODO confirm for callers).

  Returns:
    The image cropped to a square of side min(height, width), taken from the
    center of the longer dimension.
  """
  height, width = image.shape[1], image.shape[2]
  side = min(height, width)
  # Only the longer dimension gets a non-zero offset; the shorter one already
  # matches the square side exactly.
  top = max(height - width, 0) // 2
  left = max(width - height, 0) // 2
  return tf.image.crop_to_bounding_box(image, top, left, side, side)

@functools.lru_cache(maxsize=None)
def load_image(image_url, image_size=(256, 256), preserve_aspect_ratio=True):
  """Download, normalize, center-crop and resize an image.

  Results are memoized by lru_cache on (image_url, image_size,
  preserve_aspect_ratio), so every argument must be hashable -- pass
  image_size as a tuple, not a list.

  Args:
    image_url: URL of the image. The file is cached locally by
      tf.keras.utils.get_file under (at most) the last 128 characters of the
      URL's basename.
    image_size: target (height, width) handed to tf.image.resize.
    preserve_aspect_ratio: forwarded to tf.image.resize. Because the image is
      center-cropped to a square first, this only changes the result when
      image_size itself is not square.

  Returns:
    A float32 image with a leading batch dimension of 1, normalized to the
    range [0, 1] and resized to image_size.
  """
  # Cache image file locally.
  image_path = tf.keras.utils.get_file(
      os.path.basename(image_url)[-128:], image_url)
  # Load and convert to float32 numpy array, add batch dimension.
  img = plt.imread(image_path).astype(np.float32)[np.newaxis, ...]
  # plt.imread may return 0-255 values (e.g. JPEG) or floats already in
  # [0, 1] (e.g. PNG), so only rescale when values exceed 1.
  if img.max() > 1.0:
    img = img / 255.
  # A grayscale image is (1, H, W) at this point; replicate it into three
  # identical channels so the result is (1, H, W, 3).
  if len(img.shape) == 3:
    img = tf.stack([img, img, img], axis=-1)
  img = crop_center(img)
  # Honor the caller's preserve_aspect_ratio (previously hard-coded to True,
  # which silently ignored the parameter).
  img = tf.image.resize(
      img, image_size, preserve_aspect_ratio=preserve_aspect_ratio)
  return img

def show_n(images, titles=('',)):
  """Display several batched images side by side in one matplotlib row.

  Args:
    images: sequence of batched images. images[i][0] is drawn, and
      images[i].shape[1] sets that column's relative width and (for the
      first image) the overall figure size.
    titles: optional per-image titles; images without a title get ''.
  """
  n = len(images)
  # int() is required under TF 1.x, where Tensor.shape[i] is a tf.Dimension
  # object rather than a Python int. A Dimension propagated into figsize makes
  # matplotlib's np.isfinite check raise TypeError. With TF 2.x tensors and
  # NumPy arrays the value is already an int, so int() is a no-op there.
  image_sizes = [int(image.shape[1]) for image in images]
  # Clamp to 1 so very small images (< 54 px) don't produce a zero figsize,
  # which matplotlib rejects with ValueError.
  w = max(1, (image_sizes[0] * 6) // 320)
  plt.figure(figsize=(w * n, w))
  gs = gridspec.GridSpec(1, n, width_ratios=image_sizes)
  for i, image in enumerate(images):
    plt.subplot(gs[i])
    plt.imshow(image[0], aspect='equal')
    plt.axis('off')
    plt.title(titles[i] if i < len(titles) else '')
  plt.show()

# @title Load example images  { display-mode: "form" }

content_image_url = 'https://upload.wikimedia.org/wikipedia/commons/thumb/f/fd/Golden_Gate_Bridge_from_Battery_Spencer.jpg/640px-Golden_Gate_Bridge_from_Battery_Spencer.jpg'  # @param {type:"string"}
style_image_url = 'https://upload.wikimedia.org/wikipedia/commons/0/0a/The_Great_Wave_off_Kanagawa.jpg'  # @param {type:"string"}
output_image_size = 384  # @param {type:"integer"}

# The content image size can be arbitrary.
content_img_size = (output_image_size, output_image_size)
# The style prediction model was trained with image size 256 and it's the
# recommended image size for the style image (though other sizes work as
# well, they will lead to different results).
style_img_size = (256, 256)  # Recommended to keep it at 256.

content_image = load_image(content_image_url, content_img_size)
style_image = load_image(style_image_url, style_img_size)
# Smooth the style image with a 3x3 average filter; stride 1 with 'SAME'
# padding keeps its spatial size unchanged.
style_image = tf.nn.avg_pool(style_image, ksize=[3,3], strides=[1,1], padding='SAME')
# show_n derives the figure size from image.shape[1]; under TF 1.x that value
# is a tf.Dimension rather than an int, which is what made this call raise
# the matplotlib TypeError reported in this question.
show_n([content_image, style_image], ['Content image', 'Style image'])

这是错误信息：

TypeError                                 Traceback (most recent call last)
<ipython-input-8-b21290c301e4> in <module>()
     14 style_image = load_image(style_image_url, style_img_size)
     15 style_image = tf.nn.avg_pool(style_image, ksize=[3,3], strides=[1,1], padding='SAME')
---> 16 show_n([content_image, style_image], ['Content image', 'Style image'])

3 frames
<ipython-input-3-a1ddf5894992> in show_n(images, titles)
     29   image_sizes = [image.shape[1] for image in images]
     30   w = (image_sizes[0] * 6) // 320
---> 31   plt.figure(figsize=(w  * n, w))
     32   gs = gridspec.GridSpec(1, n, width_ratios=image_sizes)
     33   for i in range(n):

/usr/local/lib/python3.6/dist-packages/matplotlib/pyplot.py in figure(num, figsize, dpi, facecolor, edgecolor, frameon, FigureClass, clear, **kwargs)
    544                                         frameon=frameon,
    545                                         FigureClass=FigureClass,
--> 546                                         **kwargs)
    547 
    548         if figLabel:

/usr/local/lib/python3.6/dist-packages/matplotlib/backend_bases.py in new_figure_manager(cls, num, *args, **kwargs)
   3322         from matplotlib.figure import Figure
   3323         fig_cls = kwargs.pop('FigureClass', Figure)
-> 3324         fig = fig_cls(*args, **kwargs)
   3325         return cls.new_figure_manager_given_figure(num, fig)
   3326 

/usr/local/lib/python3.6/dist-packages/matplotlib/figure.py in __init__(self, figsize, dpi, facecolor, edgecolor, linewidth, frameon, subplotpars, tight_layout, constrained_layout)
    346             frameon = rcParams['figure.frameon']
    347 
--> 348         if not np.isfinite(figsize).all() or (np.array(figsize) <= 0).any():
    349             raise ValueError('figure size must be positive finite not '
    350                              f'{figsize}')

TypeError: ufunc 'isfinite' not supported for the input types, and the inputs could not be safely coerced to any supported types according to the casting rule ''safe''

谁能解释一下问题是什么以及解决方法?提前致谢。

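我使用的正是您提供并链接的那个 Google 托管 Colab 中的代码，唯一的改动是选择了 TensorFlow 2.x 版本。
原因在于：在 TensorFlow 1.x 中，`image.shape[1]` 返回的是 `tf.Dimension` 对象而不是 Python `int`，因此 `show_n` 计算出的 `figsize` 无法通过 matplotlib 内部的 `np.isfinite` 检查，从而抛出上面的 `TypeError`；在 TensorFlow 2.x 中它返回普通的 `int`，所以不会出错。（也可以在 `show_n` 中改用 `int(image.shape[1])` 来规避这个问题。）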

在笔记本的开头添加下面这行代码：

%tensorflow_version 2.x

`2.x` 表示让 Google Colab 选择当前可用的最新稳定版 TensorFlow 2。