网络内部输入的加权和
Weighted sum of an input inside the network
我有一个包含多个输入的网络,我拆分出前 10 个输入并计算加权和,然后将其与其余输入连接起来:
first = Lambda(lambda z: z[:, 0:11])(d_inputs)
wsum_first = Lambda(calcWSumF)(first )
d_input = concatenate([d_inputs, wsum_first], axis=-1)
函数定义为:
w_vec = K.constant(np.array([range(10)]*64).reshape(10, 64)) # batch size is 64
def calcWSumF(x):
y = K.dot(w_vec, x)
y = K.expand_dims(y, -1)
return y
我想要一个常数向量来计算输入第一部分的加权和。连接不起作用,因为形状不匹配。我该如何正确实施?
您可以使用 K.sum
并且仅使用包含系数的向量来更好地编写此代码。此外,不需要使用固定的批量大小(可以是任何数字):
def calcWSumF(x, idx):
w_vec = K.constant(np.arange(idx))
y = K.sum(x[:, 0:idx] * w_vec, axis=-1, keepdims=True)
return y
d_inputs = Input((15,))
wsum_first = Lambda(calcWSumF, arguments={'idx': 10})(d_inputs)
d_input = concatenate([d_inputs, wsum_first], axis=-1)
model = Model(d_inputs, d_input)
model.predict(np.arange(15).reshape(1, 15))
# output:
array([[ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.,
11., 12., 13., 14., 285.]], dtype=float32)
# Note: 0*0 + 1*1 + 2*2 + ... + 9*9 = 285
请注意,为了使其更通用,我们向 lambda 函数添加了另一个参数 (idx
),它指定了我们要考虑的元素数量。
我有一个包含多个输入的网络,我拆分出前 10 个输入并计算加权和,然后将其与其余输入连接起来:
first = Lambda(lambda z: z[:, 0:11])(d_inputs)
wsum_first = Lambda(calcWSumF)(first )
d_input = concatenate([d_inputs, wsum_first], axis=-1)
函数定义为:
w_vec = K.constant(np.array([range(10)]*64).reshape(10, 64)) # batch size is 64
def calcWSumF(x):
y = K.dot(w_vec, x)
y = K.expand_dims(y, -1)
return y
我想要一个常数向量来计算输入第一部分的加权和。连接不起作用,因为形状不匹配。我该如何正确实施?
您可以使用 K.sum
并且仅使用包含系数的向量来更好地编写此代码。此外,不需要使用固定的批量大小(可以是任何数字):
def calcWSumF(x, idx):
w_vec = K.constant(np.arange(idx))
y = K.sum(x[:, 0:idx] * w_vec, axis=-1, keepdims=True)
return y
d_inputs = Input((15,))
wsum_first = Lambda(calcWSumF, arguments={'idx': 10})(d_inputs)
d_input = concatenate([d_inputs, wsum_first], axis=-1)
model = Model(d_inputs, d_input)
model.predict(np.arange(15).reshape(1, 15))
# output:
array([[ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.,
11., 12., 13., 14., 285.]], dtype=float32)
# Note: 0*0 + 1*1 + 2*2 + ... + 9*9 = 285
请注意,为了使其更通用,我们向 lambda 函数添加了另一个参数 (idx
),它指定了我们要考虑的元素数量。