在 tensorflow js 中编写自定义 InstantLayerNormalization

Writing custom InstantLayerNormalization in tensorflow js

我正在尝试在浏览器中实现深度学习模型,这需要移植一些自定义层,其中之一是即时层规范化。在应该工作但有点旧的代码段下方。 我收到此错误:

Uncaught (in promise) ReferenceError: initializer is not defined at InstantLayerNormalization.build

<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js"> </script>
<script>
class InstantLayerNormalization extends tf.layers.Layer
{
    static className = 'InstantLayerNormalization';
    epsilon = 1e-7 
    gamma;
    beta;
    constructor(config) 
    {
        super(config);
    }
    getConfig() 
    {
        const config = super.getConfig();
        return config;
    }
    
    build(input_shape)
    {
        let shape = tf.tensor(input_shape);
        // initialize gamma
        self.gamma = self.add_weight(shape=shape, 
                                     initializer='ones', 
                                     trainable=true, 
                                     name='gamma')
        // initialize beta
        self.beta = self.add_weight(shape=shape,
                            initializer='zeros',
                            trainable=true,
                            name='beta')
    }        

    call(inputs){
        mean = tf.math.reduce_mean(inputs, axis=[-1], keepdims=True)
        variance = tf.math.reduce_mean(tf.math.square(inputs - mean), axis=[-1], keepdims=True)
        std = tf.math.sqrt(variance + self.epsilon)
        outputs = (inputs - mean) / std
        outputs = outputs * self.gamma
        outputs = outputs + self.beta
        return outputs
    }
    static get className() {
        console.log(className);
       return className;
    }
}

tf.serialization.registerClass(InstantLayerNormalization);
</script>

继承classtf.layers.Layer的方法调用不正确

  • self 在 python 中是 this 在 js
  • add_weight 更像是 addWeight
  • HereaddWeight 方法的签名。请注意,在 js 中,函数参数解构赋值
  • 没有格式 variable=value
// instead of this
self.gamma = self.add_weight(shape=shape, initializer='ones', trainable=true, name='gamma')
// it should rather be
this.gamma = this.addWeight('gamma', shape, undefined, 'ones', undefined, true)