通过混合模式覆盖方法无法按预期工作

Overwriting methods via mixin pattern does not work as intended

我正在尝试为 problem 引入 mod/mixin。我在这里特别关注 SpeechRecognitionProblem。我打算 mod 解决这个问题,因此我试图做到以下几点:

class SpeechRecognitionProblemMod(speech_recognition.SpeechRecognitionProblem):

    def hparams(self, defaults, model_hparams):
        SpeechRecognitionProblem.hparams(self, defaults, model_hparams)
        vocab_size = self.feature_encoders(model_hparams.data_dir)['targets'].vocab_size
        p = defaults
        p.vocab_size['targets'] = vocab_size

    def feature_encoders(self, data_dir): 
        # ...

所以这个没什么用。它从基础 class 调用 hparams() 函数,然后更改一些值。

现在,已经有一些现成的问题,例如Libri 演讲:

@registry.register_problem()
class Librispeech(speech_recognition.SpeechRecognitionProblem):
    # ..

但是,为了应用我的 mod化验,我正在这样做:

@registry.register_problem()
class LibrispeechMod(SpeechRecognitionProblemMod, Librispeech):
    # ..

如果我没记错的话,这应该会覆盖 Librispeech 中的所有内容(具有相同的签名),而不是调用 SpeechRecognitionProblemMod.

的函数

因为我能够使用此代码训练 model,所以我假设它到目前为止按预期工作。

我的问题来了:

训练后我想序列化 model。这通常有效。然而,它不适合我的 mod 我真的知道为什么:

在某个时刻 hparams() 被调用。调试到那个点会显示以下内容:

self                  # {LibrispeechMod}
self.hparams          # <bound method SpeechRecognitionProblem.hparams of ..>
self.feature_encoders # <bound method SpeechRecognitionProblemMod.feature_encoders of ..>

self.hparams应该是<bound method SpeechRecognitionProblemMod.hparams of ..>!似乎由于某种原因 SpeechRecognitionProblemhparams() 被直接调用而不是 SpeechRecognitionProblemMod。但是 请注意 它是 feature_encoders()!

的正确类型

问题是我知道这在训练期间有效。我可以看到相应地应用了超参数 (hparams),因为 model 的图形节点名称通过我的 modifications 发生了变化。

我需要指出一个专业。 tensor2tensor 允许动态加载 t2t_usr_dir,这是由 import_usr_dir 加载的附加 python mod 规则。我也在我的序列化脚本中使用了该函数:

if usr_dir:
    logging.info('Loading user dir %s' % usr_dir)
    import_usr_dir(usr_dir)

这可能是目前我能看到的唯一罪魁祸首,尽管我无法说出为什么这会导致问题。

如果有人看到我没有看到的东西,我很乐意得到提示,提示我这里做错了什么。


那么您遇到的错误是什么?

为了完整起见,这是调用了错误的 hparams() 方法的结果:

NotFoundError (see above for traceback): Restoring from checkpoint failed.
Key transformer/symbol_modality_256_256/softmax/weights_0 not found in checkpoint

symbol_modality_256_256 是错误的。它应该是 symbol_modality_<vocab-size>_256,其中 <vocab-size> 是在 SpeechRecognitionProblemMod.hparams 中设置的词汇量大小。

所以,这种奇怪的行为是因为我在远程调试并且 usr_dir 的源文件没有正确同步。一切都按预期工作,但源文件不匹配。

案件结案。