Upsampling with pooling indices in Keras (unpooling)
Let me start by saying I'm new to deep learning.
I'm trying to write a SegNet in Keras that uses the pooling indices for upsampling.
I use this function, wrapped in a Lambda layer, to perform max pooling and save the pooling indices:
def pool_argmax2D(x, pool_size=(2, 2), strides=(2, 2)):
    # Max pooling that also returns the flattened argmax indices.
    padding = 'SAME'
    ksize = [1, pool_size[0], pool_size[1], 1]
    strides = [1, strides[0], strides[1], 1]
    output, argmax = tf.nn.max_pool_with_argmax(
        x,
        ksize=ksize,
        strides=strides,
        padding=padding
    )
    return [output, argmax]
[...]
pool_4, mask_4 = Lambda(pool_argmax2D, arguments={'pool_size': pool_size, 'strides': pool_size})(conv_10)
[...]
It seems to work: in my model summary it returns a tensor of shape (None, h/2, w/2, channels).
However, I'm having some trouble finding or writing a working unpooling function.
I can't get back a tensor of shape (None, 2h, 2w, channels)
(None being the batch size).
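To make the shapes concrete, here is a minimal standalone sketch of the round trip I'm after (my own toy example with batch size 1 and channels-last, not code from the model below): argmax pooling halves h and w, and scattering the pooled values back at the saved indices restores them:

import tensorflow as tf

x = tf.random.normal((1, 4, 4, 3))  # (batch, h, w, channels)
pooled, argmax = tf.nn.max_pool_with_argmax(
    x, ksize=(1, 2, 2, 1), strides=(1, 2, 2, 1), padding='SAME')
print(pooled.shape)  # (1, 2, 2, 3): h and w halved
# Goal: scatter `pooled` back to (1, 4, 4, 3) at the argmax positions.
# With batch size 1 the flattened indices address the whole tensor directly.
total = tf.reshape(tf.size(x, out_type=tf.int64), (1,))
unpooled = tf.reshape(
    tf.scatter_nd(tf.reshape(argmax, (-1, 1)),
                  tf.reshape(pooled, (-1,)),
                  shape=total),
    tf.shape(x))
print(unpooled.shape)  # (1, 4, 4, 3): h and w restored, zeros elsewhere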
I've already tried these unpooling functions found on Stack Overflow (among others):
Function1
with no result.
Can someone help me? Thanks.
EDIT:
This is the model I'm trying to use:
def getSegNet3(n_ch, height, width, n_labels, pool_size=(2, 2), output_mode="sigmoid"):
    # encoder
    inputs = Input(shape=(n_ch, height, width))
    conv_1 = Conv2D(16, (3, 3), kernel_initializer='he_normal', padding='same', data_format='channels_first')(inputs)
    conv_1 = BatchNormalization(axis=1)(conv_1)
    conv_1 = Activation("relu")(conv_1)
    conv_2 = Conv2D(16, (3, 3), kernel_initializer='he_normal', padding='same', data_format='channels_first')(conv_1)
    conv_2 = BatchNormalization(axis=1)(conv_2)
    conv_2 = Activation("relu")(conv_2)
    conv_2 = core.Permute((2, 3, 1))(conv_2)
    pool_1, mask_1 = Lambda(pool_argmax2D, arguments={'pool_size': pool_size, 'strides': pool_size})(conv_2)
    pool_1 = core.Permute((3, 1, 2))(pool_1)
    conv_3 = Conv2D(32, (3, 3), kernel_initializer='he_normal', padding='same', data_format='channels_first')(pool_1)
    conv_3 = BatchNormalization(axis=1)(conv_3)
    conv_3 = Activation("relu")(conv_3)
    conv_4 = Conv2D(32, (3, 3), kernel_initializer='he_normal', padding='same', data_format='channels_first')(conv_3)
    conv_4 = BatchNormalization(axis=1)(conv_4)
    conv_4 = Activation("relu")(conv_4)
    conv_4 = core.Permute((2, 3, 1))(conv_4)
    pool_2, mask_2 = Lambda(pool_argmax2D, arguments={'pool_size': pool_size, 'strides': pool_size})(conv_4)
    pool_2 = core.Permute((3, 1, 2))(pool_2)
    conv_5 = Conv2D(64, (3, 3), kernel_initializer='he_normal', padding='same', data_format='channels_first')(pool_2)
    conv_5 = BatchNormalization(axis=1)(conv_5)
    conv_5 = Activation("relu")(conv_5)
    conv_6 = Conv2D(64, (3, 3), kernel_initializer='he_normal', padding='same', data_format='channels_first')(conv_5)
    conv_6 = BatchNormalization(axis=1)(conv_6)
    conv_6 = Activation("relu")(conv_6)
    conv_7 = Conv2D(64, (3, 3), kernel_initializer='he_normal', padding='same', data_format='channels_first')(conv_6)
    conv_7 = BatchNormalization(axis=1)(conv_7)
    conv_7 = Activation("relu")(conv_7)
    conv_7 = core.Permute((2, 3, 1))(conv_7)
    pool_3, mask_3 = Lambda(pool_argmax2D, arguments={'pool_size': pool_size, 'strides': pool_size})(conv_7)
    pool_3 = core.Permute((3, 1, 2))(pool_3)
    conv_8 = Conv2D(128, (3, 3), kernel_initializer='he_normal', padding='same', data_format='channels_first')(pool_3)
    conv_8 = BatchNormalization(axis=1)(conv_8)
    conv_8 = Activation("relu")(conv_8)
    conv_9 = Conv2D(128, (3, 3), kernel_initializer='he_normal', padding='same', data_format='channels_first')(conv_8)
    conv_9 = BatchNormalization(axis=1)(conv_9)
    conv_9 = Activation("relu")(conv_9)
    conv_10 = Conv2D(128, (3, 3), kernel_initializer='he_normal', padding='same', data_format='channels_first')(conv_9)
    conv_10 = BatchNormalization(axis=1)(conv_10)
    conv_10 = Activation("relu")(conv_10)
    conv_10 = core.Permute((2, 3, 1))(conv_10)
    pool_4, mask_4 = Lambda(pool_argmax2D, arguments={'pool_size': pool_size, 'strides': pool_size})(conv_10)
    pool_4 = core.Permute((3, 1, 2))(pool_4)
    conv_11 = Conv2D(256, (3, 3), kernel_initializer='he_normal', padding='same', data_format='channels_first')(pool_4)
    conv_11 = BatchNormalization(axis=1)(conv_11)
    conv_11 = Activation("relu")(conv_11)
    conv_12 = Conv2D(256, (3, 3), kernel_initializer='he_normal', padding='same', data_format='channels_first')(conv_11)
    conv_12 = BatchNormalization(axis=1)(conv_12)
    conv_12 = Activation("relu")(conv_12)
    conv_13 = Conv2D(256, (3, 3), kernel_initializer='he_normal', padding='same', data_format='channels_first')(conv_12)
    conv_13 = BatchNormalization(axis=1)(conv_13)
    conv_13 = Activation("relu")(conv_13)
    conv_13 = core.Permute((2, 3, 1))(conv_13)
    pool_5, mask_5 = Lambda(pool_argmax2D, arguments={'pool_size': pool_size, 'strides': pool_size})(conv_13)
    print("Build encoder done..")
    # decoder
    # unpool_1 = MaxUnpooling2D(pool_5, mask_5, (None, 4, 4, 256))
    unpool_1 = Lambda(unpool2D, output_shape=(4, 4, 256), arguments={'ind': mask_5})(pool_5)
    unpool_1 = core.Permute((3, 1, 2))(unpool_1)
    conv_14 = Conv2D(256, (3, 3), kernel_initializer='he_normal', padding='same', data_format='channels_first')(unpool_1)
    conv_14 = BatchNormalization(axis=1)(conv_14)
    conv_14 = Activation("relu")(conv_14)
    conv_15 = Conv2D(256, (3, 3), kernel_initializer='he_normal', padding='same', data_format='channels_first')(conv_14)
    conv_15 = BatchNormalization(axis=1)(conv_15)
    conv_15 = Activation("relu")(conv_15)
    conv_16 = Conv2D(256, (3, 3), kernel_initializer='he_normal', padding='same', data_format='channels_first')(conv_15)
    conv_16 = BatchNormalization(axis=1)(conv_16)
    conv_16 = Activation("relu")(conv_16)
    conv_16 = core.Permute((2, 3, 1))(conv_16)
    unpool_2 = Lambda(unpool2D, output_shape=(8, 8, 256), arguments={'ind': mask_4})(conv_16)
    unpool_2 = core.Permute((3, 1, 2))(unpool_2)
    conv_17 = Conv2D(256, (3, 3), kernel_initializer='he_normal', padding='same', data_format='channels_first')(unpool_2)
    conv_17 = BatchNormalization(axis=1)(conv_17)
    conv_17 = Activation("relu")(conv_17)
    conv_18 = Conv2D(256, (3, 3), kernel_initializer='he_normal', padding='same', data_format='channels_first')(conv_17)
    conv_18 = BatchNormalization(axis=1)(conv_18)
    conv_18 = Activation("relu")(conv_18)
    conv_19 = Conv2D(128, (3, 3), kernel_initializer='he_normal', padding='same', data_format='channels_first')(conv_18)
    conv_19 = BatchNormalization(axis=1)(conv_19)
    conv_19 = Activation("relu")(conv_19)
    conv_19 = core.Permute((2, 3, 1))(conv_19)
    unpool_3 = Lambda(unpool2D, output_shape=(16, 16, 128), arguments={'ind': mask_3})(conv_19)
    unpool_3 = core.Permute((3, 1, 2))(unpool_3)
    conv_20 = Conv2D(128, (3, 3), kernel_initializer='he_normal', padding='same', data_format='channels_first')(unpool_3)
    conv_20 = BatchNormalization(axis=1)(conv_20)
    conv_20 = Activation("relu")(conv_20)
    conv_21 = Conv2D(128, (3, 3), kernel_initializer='he_normal', padding='same', data_format='channels_first')(conv_20)
    conv_21 = BatchNormalization(axis=1)(conv_21)
    conv_21 = Activation("relu")(conv_21)
    conv_22 = Conv2D(64, (3, 3), kernel_initializer='he_normal', padding='same', data_format='channels_first')(conv_21)
    conv_22 = BatchNormalization(axis=1)(conv_22)
    conv_22 = Activation("relu")(conv_22)
    conv_22 = core.Permute((2, 3, 1))(conv_22)
    unpool_4 = Lambda(unpool2D, output_shape=(32, 32, 64), arguments={'ind': mask_2})(conv_22)
    unpool_4 = core.Permute((3, 1, 2))(unpool_4)
    conv_23 = Conv2D(64, (3, 3), kernel_initializer='he_normal', padding='same', data_format='channels_first')(unpool_4)
    conv_23 = BatchNormalization(axis=1)(conv_23)
    conv_23 = Activation("relu")(conv_23)
    conv_24 = Conv2D(32, (3, 3), kernel_initializer='he_normal', padding='same', data_format='channels_first')(conv_23)
    conv_24 = BatchNormalization(axis=1)(conv_24)
    conv_24 = Activation("relu")(conv_24)
    conv_24 = core.Permute((2, 3, 1))(conv_24)
    unpool_5 = Lambda(unpool2D, output_shape=(64, 64, 32), arguments={'ind': mask_1})(conv_24)
    unpool_5 = core.Permute((3, 1, 2))(unpool_5)
    conv_25 = Conv2D(32, (3, 3), kernel_initializer='he_normal', padding='same', data_format='channels_first')(unpool_5)
    conv_25 = BatchNormalization(axis=1)(conv_25)
    conv_25 = Activation("relu")(conv_25)
    conv_26 = Convolution2D(n_labels, (1, 1), padding="valid", data_format="channels_first")(conv_25)
    conv_26 = BatchNormalization(axis=1)(conv_26)
    outputs = Activation(output_mode)(conv_26)
    print("Build decoder done..")
    model = Model(inputs=inputs, outputs=outputs, name="SegNet")
    return model
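For reference, the hard-coded output_shape values in the decoder (4×4 up to 64×64) imply a 64×64 input, so I build the model roughly like this (n_ch and n_labels here are placeholder values of my choosing):

model = getSegNet3(n_ch=3, height=64, width=64, n_labels=1)
model.summary()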
The unpooling function I'm trying to use:
def unpool2D(pool, ind, ksize=(2, 2)):
    with tf.compat.v1.variable_scope("unpool"):
        input_shape = tf.shape(pool)
        output_shape = [input_shape[0],
                        input_shape[1] * ksize[0],
                        input_shape[2] * ksize[1],
                        input_shape[3]]
        flat_input_size = tf.math.cumprod(input_shape)[-1]
        flat_output_shape = tf.cast([output_shape[0],
                                     output_shape[1] * output_shape[2] * output_shape[3]], tf.int64)
        pool_ = tf.reshape(pool, [flat_input_size])
        batch_range = tf.reshape(tf.range(tf.cast(output_shape[0], tf.int64), dtype=tf.int64),
                                 shape=[input_shape[0], 1, 1, 1])
        b = tf.ones_like(ind) * batch_range
        b = tf.reshape(b, [flat_input_size, 1])
        ind_ = tf.reshape(ind, [flat_input_size, 1]) % flat_output_shape[1]
        ind_ = tf.concat([b, ind_], 1)
        ret = tf.scatter_nd(ind_, pool_, shape=flat_output_shape)
        ret = tf.reshape(ret, output_shape)
        return ret
This is what I get:
~/bones-adamo/models.py in getSegNet3(n_ch, height, width, n_labels, pool_size, output_mode)
1013 unpool_1 = core.Permute((3, 1, 2))(unpool_1)
1014
-> 1015 conv_14 = Conv2D(256, (3, 3), kernel_initializer='he_normal', padding='same',data_format='channels_first')(unpool_1)
1016 conv_14 = BatchNormalization(axis=1)(conv_14)
1017 conv_14 = Activation("relu")(conv_14)
~/venv/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py in __call__(self, *args, **kwargs)
923 # >> model = tf.keras.Model(inputs, outputs)
924 if _in_functional_construction_mode(self, inputs, args, kwargs, input_list):
--> 925 return self._functional_construction_call(inputs, args, kwargs,
926 input_list)
927
~/venv/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py in _functional_construction_call(self, inputs, args, kwargs, input_list)
1096 # Build layer if applicable (if the `build` method has been
1097 # overridden).
-> 1098 self._maybe_build(inputs)
1099 cast_inputs = self._maybe_cast_inputs(inputs, input_list)
1100
~/venv/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py in _maybe_build(self, inputs)
2641 # operations.
2642 with tf_utils.maybe_init_scope(self):
-> 2643 self.build(input_shapes) # pylint:disable=not-callable
2644 # We must set also ensure that the layer is marked as built, and the build
2645 # shape is stored since user defined build functions may not be calling
~/venv/lib/python3.8/site-packages/tensorflow/python/keras/layers/convolutional.py in build(self, input_shape)
185 def build(self, input_shape):
186 input_shape = tensor_shape.TensorShape(input_shape)
--> 187 input_channel = self._get_input_channel(input_shape)
188 if input_channel % self.groups != 0:
189 raise ValueError(
~/venv/lib/python3.8/site-packages/tensorflow/python/keras/layers/convolutional.py in _get_input_channel(self, input_shape)
357 channel_axis = self._get_channel_axis()
358 if input_shape.dims[channel_axis].value is None:
--> 359 raise ValueError('The channel dimension of the inputs '
360 'should be defined. Found `None`.')
361 return int(input_shape[channel_axis])
ValueError: The channel dimension of the inputs should be defined. Found `None`.
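As far as I can tell, the failure is that the Lambda output has lost its static shape: unpool2D builds output_shape from tf.shape(), so after the reshape the tensor carries no static channel dimension for Conv2D to build against. A quick check (my own diagnostic, not part of the traceback above) makes this visible:

from tensorflow.keras import backend as K
print(K.int_shape(unpool_1))  # presumably (None, None, None, None): no static channels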
OK, I solved my problem. There was a model architecture issue that I didn't spot at first.
If you want to upsample using pooling indices, I suggest you use the custom layers here.
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Layer


class MaxUnpooling2D(Layer):
    def __init__(self, size=(2, 2), **kwargs):
        super(MaxUnpooling2D, self).__init__(**kwargs)
        self.size = size

    def call(self, inputs, output_shape=None):
        updates, mask = inputs[0], inputs[1]
        with tf.compat.v1.variable_scope(self.name):
            mask = K.cast(mask, 'int32')
            input_shape = tf.shape(updates, out_type='int32')
            if output_shape is None:
                output_shape = (
                    input_shape[0],
                    input_shape[1] * self.size[0],
                    input_shape[2] * self.size[1],
                    input_shape[3])
            # Scatter the pooled values back to the positions saved in the mask.
            ret = tf.scatter_nd(K.expand_dims(K.flatten(mask)),
                                K.flatten(updates),
                                [K.prod(output_shape)])
            input_shape = updates.shape
            out_shape = [-1,
                         input_shape[1] * self.size[0],
                         input_shape[2] * self.size[1],
                         input_shape[3]]
            return K.reshape(ret, out_shape)

    def get_config(self):
        config = super().get_config().copy()
        config.update({
            'size': self.size
        })
        return config

    def compute_output_shape(self, input_shape):
        mask_shape = input_shape[1]
        return (
            mask_shape[0],
            mask_shape[1] * self.size[0],
            mask_shape[2] * self.size[1],
            mask_shape[3]
        )
Usage example:
unpool_3 = MaxUnpooling2D()([conv_19, mask_3])
I added get_config to avoid this error:
NotImplementedError: Layer MaxPoolingWithArgmax2D has arguments in `__init__` and therefore must override `get_config`.
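For completeness, the pooling layer named in that error can override get_config the same way. Here is a sketch modeled on the same family of custom SegNet layers (channels-last; the exact signature and compute_output_shape details are my assumptions):

class MaxPoolingWithArgmax2D(Layer):
    def __init__(self, pool_size=(2, 2), strides=(2, 2), padding='same', **kwargs):
        super(MaxPoolingWithArgmax2D, self).__init__(**kwargs)
        self.pool_size = pool_size
        self.strides = strides
        self.padding = padding

    def call(self, inputs):
        ksize = [1, self.pool_size[0], self.pool_size[1], 1]
        strides = [1, self.strides[0], self.strides[1], 1]
        output, argmax = tf.nn.max_pool_with_argmax(
            inputs, ksize=ksize, strides=strides, padding=self.padding.upper())
        # Cast so the mask flows through Keras as an ordinary float tensor.
        return [output, K.cast(argmax, K.floatx())]

    def compute_output_shape(self, input_shape):
        ratio = (1, self.pool_size[0], self.pool_size[1], 1)
        output_shape = tuple(dim // r if dim is not None else None
                             for dim, r in zip(input_shape, ratio))
        return [output_shape, output_shape]

    def get_config(self):
        # Serializing the constructor arguments avoids the NotImplementedError above.
        config = super().get_config().copy()
        config.update({'pool_size': self.pool_size,
                       'strides': self.strides,
                       'padding': self.padding})
        return config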
I hope this answer helps other users.