Inserting a layer in an existing tf model, error at layer tf.__operators__.getitem()

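I am trying to implement feature visualisation/optimisation with YOLOv4. For the Python implementation I use this GitHub repo, and I add a layer, identical to the input layer, right after the input layer. To insert this layer I use a function that was posted online (the one described in the accepted answer). Unfortunately, this function did not work out of the box, because some arguments were not passed on to certain layers, so I had to pass them explicitly. I therefore changed the last else statement and now have this code: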

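import re

from tensorflow.keras.layers import Conv2D
from tensorflow.keras.models import Model


def main(argv):
    # helper is a project-specific module (not shown) that returns the YOLOv4 Keras model
    source_model = helper.get_model()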

    model = insert_layer_nonseq(source_model, "conv2d", opt_layer_factory, position="before", regex=False)

    model.summary()

def opt_layer_factory():
    return Conv2D(3, 1, activation='relu', input_shape=(416, 416, 3), name="opt_layer")


def insert_layer_nonseq(model, layer_regex, insert_layer_factory, insert_layer_name=None, position='after', regex=True):
    # Auxiliary dictionary to describe the network graph
    network_dict = {'input_layers_of': {}, 'new_output_tensor_of': {}}

    # Set the input layers of each layer
    for layer in model.layers:
        for node in layer._outbound_nodes:
            layer_name = node.outbound_layer.name
            if layer_name not in network_dict['input_layers_of']:
                network_dict['input_layers_of'].update(
                    {layer_name: [layer.name]})
            else:
                network_dict['input_layers_of'][layer_name].append(layer.name)

    # Set the output tensor of the input layer
    network_dict['new_output_tensor_of'].update(
        {model.layers[0].name: model.input})

    # Iterate over all layers after the input
    model_outputs = []
    for layer in model.layers[1:]:

        # Determine input tensors
        layer_input = [network_dict['new_output_tensor_of'][layer_aux]
                for layer_aux in network_dict['input_layers_of'][layer.name]]
        if len(layer_input) == 1:
            layer_input = layer_input[0]

        # Insert layer if name matches the regular expression or if names match
        regex_or_name = re.match(layer_regex, layer.name) if regex else (layer_regex == layer.name)
        if regex_or_name:
            if position == 'replace':
                x = layer_input
            elif position == 'after':
                x = layer(layer_input)
            elif position == 'before':
                x = layer_input
            else:
                raise ValueError('position must be: before, after or replace')

            new_layer = insert_layer_factory()
            x = new_layer(x)
            print('New layer: {} Old layer: {} Type: {}'.format(new_layer.name,
                                                        layer.name, position))
            if position == 'before':
                x = layer(x)
        else:
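            # Layers wrapping tf.* functions (TFOpLambda) need the non-tensor
            # arguments of their original call again; here they are passed by
            # hand, with values specific to this model
            if re.match(r"tf\.concat", layer.name):
                x = layer(layer_input, -1)
            elif re.match(r"tf\.__operators__\.getitem", layer.name):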
                x = layer(layer_input)
            elif layer.name == "tf.nn.max_pool2d":
                x = layer(layer_input, ksize=13, padding="SAME", strides=1)
            elif layer.name == "tf.nn.max_pool2d_1":
                x = layer(layer_input, ksize=9, padding="SAME", strides=1)
            elif layer.name == "tf.nn.max_pool2d_2":
                x = layer(layer_input, ksize=5, padding="SAME", strides=1)
            elif layer.name == "tf.image.resize":
                x = layer(layer_input, size=(26, 26), method="bilinear")
            elif layer.name == "tf.image.resize_1":
                x = layer(layer_input, size=(52, 52), method="bilinear")
            elif isinstance(layer_input, list):
                x = layer(*layer_input)
            else:
                x = layer(layer_input)

        # Set new output tensor (the original one, or the one of the inserted
        # layer)
        network_dict['new_output_tensor_of'].update({layer.name: x})

        # Save tensor in output list if it is output in initial model
        if layer.name in model.output_names:
            model_outputs.append(x)

    return Model(inputs=model.inputs, outputs=model_outputs)
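The bookkeeping in `insert_layer_nonseq` (record each layer's input layers, replay the layers in order, and map every layer name to its new output) can be checked without TensorFlow. The following is a minimal pure-Python sketch of the "before" case only; `ToyLayer` and `insert_before` are hypothetical names, the layers are assumed to be listed in topological order with the input layer first, and each toy layer just returns a string describing the call so the rewired graph is easy to read.

```python
from typing import Callable, Dict, List, Sequence


class ToyLayer:
    """Hypothetical stand-in for a Keras layer: calling it returns a string
    that describes the call, e.g. "conv2d(input_1)"."""

    def __init__(self, name: str) -> None:
        self.name = name

    def __call__(self, *inputs: str) -> str:
        return f"{self.name}({', '.join(inputs)})"


def insert_before(
    layers: Sequence[ToyLayer],
    input_layers_of: Dict[str, List[str]],
    target: str,
    new_layer_factory: Callable[[], ToyLayer],
    output_names: Sequence[str],
) -> List[str]:
    """Replay `layers` (topologically sorted, input layer first) and insert a
    fresh layer in front of the layer named `target`."""
    # Plays the role of network_dict['new_output_tensor_of']: name -> new output.
    new_output_of: Dict[str, str] = {layers[0].name: layers[0].name}
    outputs: List[str] = []
    for layer in layers[1:]:
        inputs = [new_output_of[name] for name in input_layers_of[layer.name]]
        if layer.name == target:
            # The inserted layer takes over the target's inputs; the target
            # is then applied to the inserted layer's output.
            inputs = [new_layer_factory()(*inputs)]
        x = layer(*inputs)
        new_output_of[layer.name] = x
        # Check the current layer's own name against the model outputs.
        if layer.name in output_names:
            outputs.append(x)
    return outputs


if __name__ == "__main__":
    layers = [ToyLayer(n) for n in ["input_1", "conv2d", "conv2d_1", "add"]]
    input_layers_of = {
        "conv2d": ["input_1"],
        "conv2d_1": ["conv2d"],
        "add": ["conv2d", "conv2d_1"],
    }
    print(insert_before(layers, input_layers_of, "conv2d",
                        lambda: ToyLayer("opt_layer"), ["add"]))
    # → ['add(conv2d(opt_layer(input_1)), conv2d_1(conv2d(opt_layer(input_1))))']
```

Because every downstream layer looks up its inputs by name in `new_output_of`, a layer with several consumers (here `conv2d`) automatically propagates the inserted layer to all of them.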

Now I'm stuck at a layer named "tf.__operators__.getitem()", which seems to require additional arguments, but I can't figure out which ones.

Any help would be greatly appreciated :)
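A "tf.__operators__.getitem" layer is what Keras creates for a slicing expression such as `x[..., 0:2]`; its second argument is the slice specification, so calling the layer again with only the tensor cannot reproduce the original op. As far as I know, tf.keras 2.x records each call of a layer in a node (`layer._inbound_nodes`, a private attribute) whose `call_args` and `call_kwargs` hold the original arguments, including such non-tensor values. Instead of hard-coding arguments per layer name, one could therefore replay those recorded arguments with the old input tensors swapped for the new ones. The sketch below shows only that substitution step in plain Python; `substitute` and `replay_call` are hypothetical helpers, tensors are matched by identity (`id`), and plain lists stand in for tensors.

```python
import operator
from typing import Any, Callable, Dict


def substitute(structure: Any, old_to_new: Dict[int, Any]) -> Any:
    """Rebuild lists, tuples and dicts recursively, replacing every object
    whose id() is a key of `old_to_new`; all other values (slices, ints,
    strings, ...) are returned unchanged."""
    if id(structure) in old_to_new:
        return old_to_new[id(structure)]
    if isinstance(structure, list):
        return [substitute(item, old_to_new) for item in structure]
    if isinstance(structure, tuple):
        return tuple(substitute(item, old_to_new) for item in structure)
    if isinstance(structure, dict):
        return {key: substitute(value, old_to_new) for key, value in structure.items()}
    return structure


def replay_call(fn: Callable[..., Any], call_args: Any,
                call_kwargs: Dict[str, Any], old_to_new: Dict[int, Any]) -> Any:
    """Call `fn` with previously recorded arguments, with the old inputs
    swapped for their replacements."""
    args = substitute(tuple(call_args), old_to_new)
    kwargs = substitute(dict(call_kwargs), old_to_new)
    return fn(*args, **kwargs)


if __name__ == "__main__":
    old_input = [1, 2, 3, 4]      # stands in for the original input tensor
    new_input = [10, 20, 30, 40]  # stands in for the rebuilt tensor
    # Recorded call of old_input[0:2]: the slice is kept, the input is swapped.
    print(replay_call(operator.getitem, (old_input, slice(0, 2)), {},
                      {id(old_input): new_input}))
    # → [10, 20]
```

Against the real model, the mapping would presumably be built from the original upstream outputs (e.g. `{id(src.output): network_dict['new_output_tensor_of'][src.name]}`) and the arguments taken from `layer._inbound_nodes[0].call_args` / `call_kwargs`. This assumes every layer is called exactly once and has not been checked against the YOLOv4 model.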

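I found a workaround that no longer calls every layer again with the correct inputs. Instead of using the original model or its layers, I took the model's configuration, added the layer with the desired specification to that configuration, and created a tf model from it. I then iterated over the layers of the new model and set their weights to those of the original model (except, of course, for the added layer).
This does not fix the error from my original question, but it is sufficient for my project. In general, I think this is the better approach for adding layers to an existing model.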
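The configuration-based workaround could look roughly like the following sketch. It assumes the tf.keras 2.x functional config layout, in which each entry of `config["layers"]` has `class_name`, `config`, `name` and `inbound_nodes` in the nested `[[[source_name, node_index, tensor_index, kwargs]]]` form (Keras 3 serializes this differently). `insert_before_in_config` and `rebuild_with_weights` are hypothetical names, and the target layer is assumed to have a single input and to be called once.

```python
import copy
from typing import Any, Dict, Optional


def insert_before_in_config(config: Dict[str, Any], target: str,
                            new_layer: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of a functional-model config with `new_layer` (a dict
    with 'class_name' and a 'config' that contains 'name') spliced in
    front of the layer named `target`."""
    config = copy.deepcopy(config)  # leave the caller's config untouched
    layers = config["layers"]
    index = next((i for i, e in enumerate(layers) if e["name"] == target), None)
    if index is None:
        raise ValueError(f"no layer named {target!r} in config")
    target_entry = layers[index]

    nodes = target_entry["inbound_nodes"]
    if len(nodes) != 1 or len(nodes[0]) != 1:
        raise ValueError("this sketch only handles a single-input layer called once")
    source, node_index, tensor_index = nodes[0][0][:3]
    call_kwargs = nodes[0][0][3] if len(nodes[0][0]) > 3 else {}

    new_name = new_layer["config"]["name"]
    inserted = {
        "class_name": new_layer["class_name"],
        "config": copy.deepcopy(new_layer["config"]),
        "name": new_name,
        # The new layer reads what the target used to read ...
        "inbound_nodes": [[[source, node_index, tensor_index, {}]]],
    }
    # ... and the target now reads the new layer's output.
    target_entry["inbound_nodes"] = [[[new_name, 0, 0, call_kwargs]]]
    layers.insert(index, inserted)
    return config


def rebuild_with_weights(source_model, new_config: Dict[str, Any],
                         custom_objects: Optional[Dict[str, Any]] = None):
    """Build a model from `new_config` and copy weights by layer name from
    `source_model`; the inserted layer keeps its freshly initialised weights."""
    import tensorflow as tf  # imported lazily so the config helper runs without TensorFlow

    new_model = tf.keras.Model.from_config(new_config, custom_objects=custom_objects)
    for layer in new_model.layers:
        try:
            source_layer = source_model.get_layer(layer.name)
        except ValueError:  # raised for the inserted layer, which has no counterpart
            continue
        layer.set_weights(source_layer.get_weights())
    return new_model


if __name__ == "__main__":
    toy = {
        "name": "toy",
        "layers": [
            {"class_name": "InputLayer", "config": {"name": "input_1"},
             "name": "input_1", "inbound_nodes": []},
            {"class_name": "Conv2D", "config": {"name": "conv2d"},
             "name": "conv2d", "inbound_nodes": [[["input_1", 0, 0, {}]]]},
        ],
        "input_layers": [["input_1", 0, 0]],
        "output_layers": [["conv2d", 0, 0]],
    }
    new = insert_before_in_config(
        toy, "conv2d",
        {"class_name": "Conv2D", "config": {"name": "opt_layer", "filters": 3, "kernel_size": 1}})
    print([entry["name"] for entry in new["layers"]])  # → ['input_1', 'opt_layer', 'conv2d']
    print(new["layers"][2]["inbound_nodes"])            # → [[['opt_layer', 0, 0, {}]]]
```

With the model from the question this would be something like `insert_before_in_config(source_model.get_config(), "conv2d", {"class_name": "Conv2D", "config": Conv2D(3, 1, activation="relu", name="opt_layer").get_config()})` followed by `rebuild_with_weights(source_model, new_config)`. If the YOLOv4 implementation defines custom layers, they would need to be passed via `custom_objects`.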