如何正确实施过滤 data_inputs 的 Neuraxle 流水线步骤?

How to correctly implement a Neuraxle pipeline step that filters data_inputs?

我正在尝试在 neuraxle (0.5.2) 中实现一个 BaseStep 来过滤 data_input(并相应地 expected_output)。

class DataFrameQuery(NonFittableMixin, InputAndOutputTransformerMixin, BaseStep):
    def __init__(self, query):
        super().__init__()
        self.query = query
    
    def transform(self, data_input):
        data_input, expected_output = data_input
        # verify that input and output are either pd.DataFrame or pd.Series
        # ... [redacted] ...
        new_data_input = data_input.query(self.query)
        if all(output is None for output in expected_output):
            new_expected_output = [None] * len(new_data_input)
        else:
            new_expected_output = expected_output.loc[new_data_input.index]
        return new_data_input, new_expected_output

这自然(在大多数情况下)会导致 len(data_inputs)(和 expected_outputs)发生变化。在 neuraxle 的最新版本中,我得到一个 AssertionError:

data_input = pd.DataFrame([{"A": 1, "B": 1}, {"A": 2, "B": 2}], index=[1, 2])
expected_output = pd.Series([1, 2], index=[1, 2])
pipeline = Pipeline([
    DataFrameQuery("A == 1")
])
pipeline.fit_transform(data_input, expected_output)
AssertionError: InputAndOutputTransformerMixin: 
    Caching broken because there is a different len of current ids, and data inputs. 
    Please use InputAndOutputTransformerWrapper if you plan to change the len of the data inputs.

根据我的理解,这就是 Neuraxle 的 Handler Methods 应该发挥作用的地方。然而,到目前为止,我还没有找到一个可以让我在转换后更新输入和输出的 current_ids 的工具(我想它应该是 _did_transform,但这似乎不是接到电话)。

一般:

编辑:我也尝试设置 savers 并按照 here 所述使用 InputAndOutputTransformerWrapper。仍然收到以下错误(可能是因为我不确定在哪里调用 handle_transform):

AssertionError: InputAndOutputTransformerWrapper: 
    Caching broken because there is a different len of current ids, and data inputs.
    Please resample the current ids using handler methods, or create new ones by setting the wrapped step saver to HashlibMd5ValueHasher using the BaseStep.set_savers method.

编辑:目前我已经解决了如下问题:


class OutputShapeChangingStep(NonFittableMixin, InputAndOutputTransformerMixin, BaseStep):
    def __init__(self, idx):
        super().__init__()
        self.idx = idx
        
    def _update_data_container_shape(self, data_container):
        assert len(data_container.expected_outputs) == len(data_container.data_inputs)
        data_container.set_current_ids(range(len(data_container.data_inputs)))
        data_container = self.hash_data_container(data_container)
        return data_container
    
    def _set_data_inputs_and_expected_outputs(self, data_container, new_inputs, new_expected_outputs) -> DataContainer:
        data_container.set_data_inputs(new_inputs)
        data_container.set_expected_outputs(new_expected_outputs)
        data_container = self._update_data_container_shape(data_container)
        return data_container
    
    def transform(self, data_inputs):
        data_inputs, expected_outputs = data_inputs
        return data_inputs[self.idx], expected_outputs[self.idx]

在这种情况下,我很可能“错误地”覆盖了 InputAndOutputTransformerMixin_set_data_inputs_and_expected_outputs_transform_data_container 会是更好的选择吗?),但像这样更新 current_ids(并重新散列容器)似乎是可能的。但是,我仍然对如何更符合 Neuraxle 的 API 期望感兴趣。

就个人而言,我最喜欢的方法是只使用处理程序方法。我认为它更干净。

处理程序方法的使用示例:

class WindowTimeSeries(ForceHandleMixin, BaseTransformer):
   def __init__(self):
      BaseTransformer.__init__(self)
      ForceHandleMixin.__init__(self)

   def _transform_data_container(self, data_container: DataContainer, context: ExecutionContext) -> DataContainer:
      di = data_container.data_inputs
      new_di, new_eo = np.array_split(np.array(di), 2)

      return DataContainer(
        summary_id=data_container.summary_id,
        data_inputs=new_di,
        expected_outputs=new_eo
      )

这样,将重新创建当前 ID,并使用默认行为对其进行哈希处理。注意:摘要 id 是最重要的。它是在开始时创建的,并使用超参数重新散列...如果需要,您还可以使用自定义保护程序(如 HashlibMd5ValueHasher)生成新的当前 ID。

编辑,确实有bug。这是固定在这里:https://github.com/Neuraxio/Neuraxle/pull/379

用法示例:

step = InputAndOutputTransformerWrapper(WindowTimeSeriesForOutputTransformerWrapper()) \
    .set_hashers([HashlibMd5ValueHasher()])
step = StepThatInheritsFromInputAndOutputTransformerMixin() \
     .set_hashers([HashlibMd5ValueHasher()])