Keras 中 Kronecker 产品的自定义 Lambda 层 - 为 batch_size 保留的维度存在问题
Custom Lambda layer for Kronecker product in Keras - troubles with the dimension reserved for batch_size
我正在使用带有 Tensorflow 后端的 Keras 2.1.5 创建图像分类模型。
在我的模型中,我想通过计算 Kronecker product 来组合卷积层的输入和输出。我已经使用 Keras 后端函数编写了计算两个 3D 张量的克罗内克积的函数。
def kronecker_product(mat1, mat2):
#Computes the Kronecker product of two matrices.
m1, n1 = mat1.get_shape().as_list()
mat1_rsh = K.reshape(mat1, [m1, 1, n1, 1])
m2, n2 = mat2.get_shape().as_list()
mat2_rsh = K.reshape(mat2, [1, m2, 1, n2])
return K.reshape(mat1_rsh * mat2_rsh, [m1 * m2, n1 * n2])
def kronecker_product3D(tensors):
tensor1 = tensors[0]
tensor2 = tensors[1]
#Separete slices of tensor and computes appropriate matrice kronecker product
m1, n1, o1 = tensor1.get_shape().as_list()
m2, n2, o2 = tensor2.get_shape().as_list()
x_list = []
for ind1 in range(o1):
for ind2 in range(o2):
x_list.append(DenseNetKTC.kronecker_product(tensor1[:,:,ind1], tensor2[:,:,ind2]))
return K.reshape(Concatenate()(x_list), [m1 * m2, n1 * n2, o1 * o2])
然后我尝试使用 Lambda 层将操作包装到 Keras 层中:
cb = Convolution2D(12, (3,3), padding='same')(x)
x = Lambda(kronecker_product3D)([x, cb])
但收到错误“ValueError:要解压的值太多(预期为 3)”。我希望输入是 3 维的张量,但实际上它有 4 维——第一个维是为 Keras 中的 batch_size 保留的。我不知道如何处理动态大小的第四个维度。
我搜索了很多,但找不到任何手动处理批次维度的示例函数。
我很乐意提供任何提示或帮助。非常感谢!
简单的解决方法:
只需将批量维度添加到您的计算和重塑中
def kronecker_product(mat1, mat2):
#Computes the Kronecker product of two matrices.
batch, m1, n1 = mat1.get_shape().as_list()
mat1_rsh = K.reshape(mat1, [-1, m1, 1, n1, 1])
batch, m2, n2 = mat2.get_shape().as_list()
mat2_rsh = K.reshape(mat2, [-1, 1, m2, 1, n2])
return K.reshape(mat1_rsh * mat2_rsh, [-1, m1 * m2, n1 * n2])
def kronecker_product3D(tensors):
tensor1 = tensors[0]
tensor2 = tensors[1]
#Separete slices of tensor and computes appropriate matrice kronecker product
batch, m1, n1, o1 = tensor1.get_shape().as_list()
batch, m2, n2, o2 = tensor2.get_shape().as_list()
x_list = []
for ind1 in range(o1):
for ind2 in range(o2):
x_list.append(kronecker_product(tensor1[:,:,:,ind1], tensor2[:,:,:,ind2]))
return K.reshape(Concatenate()(x_list), [-1, m1 * m2, n1 * n2, o1 * o2])
对于困难的解决方案,我会尝试找出一种避免迭代的方法,但这可能比我想象的要复杂得多....
我正在使用带有 Tensorflow 后端的 Keras 2.1.5 创建图像分类模型。 在我的模型中,我想通过计算 Kronecker product 来组合卷积层的输入和输出。我已经使用 Keras 后端函数编写了计算两个 3D 张量的克罗内克积的函数。
def kronecker_product(mat1, mat2):
#Computes the Kronecker product of two matrices.
m1, n1 = mat1.get_shape().as_list()
mat1_rsh = K.reshape(mat1, [m1, 1, n1, 1])
m2, n2 = mat2.get_shape().as_list()
mat2_rsh = K.reshape(mat2, [1, m2, 1, n2])
return K.reshape(mat1_rsh * mat2_rsh, [m1 * m2, n1 * n2])
def kronecker_product3D(tensors):
tensor1 = tensors[0]
tensor2 = tensors[1]
#Separete slices of tensor and computes appropriate matrice kronecker product
m1, n1, o1 = tensor1.get_shape().as_list()
m2, n2, o2 = tensor2.get_shape().as_list()
x_list = []
for ind1 in range(o1):
for ind2 in range(o2):
x_list.append(DenseNetKTC.kronecker_product(tensor1[:,:,ind1], tensor2[:,:,ind2]))
return K.reshape(Concatenate()(x_list), [m1 * m2, n1 * n2, o1 * o2])
然后我尝试使用 Lambda 层将操作包装到 Keras 层中:
cb = Convolution2D(12, (3,3), padding='same')(x)
x = Lambda(kronecker_product3D)([x, cb])
但收到错误“ValueError:要解压的值太多(预期为 3)”。我希望输入是 3 维的张量,但实际上它有 4 维——第一个维是为 Keras 中的 batch_size 保留的。我不知道如何处理动态大小的第四个维度。
我搜索了很多,但找不到任何手动处理批次维度的示例函数。
我很乐意提供任何提示或帮助。非常感谢!
简单的解决方法:
只需将批量维度添加到您的计算和重塑中
def kronecker_product(mat1, mat2):
#Computes the Kronecker product of two matrices.
batch, m1, n1 = mat1.get_shape().as_list()
mat1_rsh = K.reshape(mat1, [-1, m1, 1, n1, 1])
batch, m2, n2 = mat2.get_shape().as_list()
mat2_rsh = K.reshape(mat2, [-1, 1, m2, 1, n2])
return K.reshape(mat1_rsh * mat2_rsh, [-1, m1 * m2, n1 * n2])
def kronecker_product3D(tensors):
tensor1 = tensors[0]
tensor2 = tensors[1]
#Separete slices of tensor and computes appropriate matrice kronecker product
batch, m1, n1, o1 = tensor1.get_shape().as_list()
batch, m2, n2, o2 = tensor2.get_shape().as_list()
x_list = []
for ind1 in range(o1):
for ind2 in range(o2):
x_list.append(kronecker_product(tensor1[:,:,:,ind1], tensor2[:,:,:,ind2]))
return K.reshape(Concatenate()(x_list), [-1, m1 * m2, n1 * n2, o1 * o2])
对于困难的解决方案,我会尝试找出一种避免迭代的方法,但这可能比我想象的要复杂得多....