TensorFlow 联合压缩:如何实现用于 TFF build_federated_averaging_process 的有状态编码器?

TensorFlow Federated Compression: How to implement a stateful encoder to be used in in TFF's build_federated_averaging_process?

在 Tensorflow Federated (TFF) 中,您可以向 tff.learning.build_federated_averaging_process 传递一个 broadcast_process 和一个 aggregation_process,它们可以嵌入自定义编码器,例如应用自定义压缩。

说到我的问题,我正在尝试实现一个编码器来稀疏模型 updates/model 权重。

我正在尝试通过实施 tensorflow_model_optimization.python.core.internal 中的 EncodingStageInterface 来构建这样的编码器。 但是,我正在努力实现一个(本地)状态来逐轮累积模型 updates/model 权重的归零坐标。请注意,不应传达此状态,只需要在本地维护(因此 AdaptiveEncodingStageInterface 应该没有帮助)。一般来说,问题是如何在编码器中维护本地状态,然后将其传递给 fedavg 进程。

我附上了我的编码器实现的代码(除了我想添加的状态之外,它可以像预期的那样在无状态下正常工作)。 然后附上我使用编码器实现的代码摘录。 如果我取消注释 stateful_encoding_stage_topk.py 中的注释部分,代码将不起作用:我不知道如何在 TF 非急切模式下管理状态(即张量)。

stateful_encoding_stage_topk.py

import tensorflow as tf
import numpy as np
from tensorflow_model_optimization.python.core.internal import tensor_encoding as te


@te.core.tf_style_encoding_stage
class StatefulTopKEncodingStage(te.core.EncodingStageInterface):

  ENCODED_VALUES_KEY = 'stateful_topk_values'
  INDICES_KEY = 'indices'
  
  
  def __init__(self):
    super().__init__()
    # Here I would like to init my state
    #self.A = tf.zeros([800], dtype=tf.float32)

  @property
  def name(self):
    """See base class."""
    return 'stateful_topk'

  @property
  def compressible_tensors_keys(self):
    """See base class."""
    return [self.ENCODED_VALUES_KEY]

  @property
  def commutes_with_sum(self):
    """See base class."""
    return True

  @property
  def decode_needs_input_shape(self):
    """See base class."""
    return True

  def get_params(self):
    """See base class."""
    return {}, {}

  def encode(self, x, encode_params):
    """See base class."""
    del encode_params  # Unused.

    dW = tf.reshape(x, [-1])
    # Here I would like to retrieve the state
    A = tf.zeros([800], dtype=tf.float32)
    #A = self.residual
    
    dW_and_A = tf.math.add(A, dW)

    percentage = tf.constant(0.4, dtype=tf.float32)
    k_float = tf.multiply(percentage, tf.cast(tf.size(dW), tf.float32))
    k_int = tf.cast(tf.math.round(k_float), dtype=tf.int32)

    values, indices = tf.math.top_k(tf.math.abs(dW_and_A), k = k_int, sorted = False)
    indices = tf.expand_dims(indices, 1)
    sparse_dW = tf.scatter_nd(indices, values, tf.shape(dW_and_A))
    
    # Here I would like to update the state
    A_updated = tf.math.subtract(dW_and_A, sparse_dW)
    #self.A = A_updated
    
    encoded_x = {self.ENCODED_VALUES_KEY: values,
                 self.INDICES_KEY: indices}

    return encoded_x

  def decode(self,
             encoded_tensors,
             decode_params,
             num_summands=None,
             shape=None):
    """See base class."""
    del decode_params, num_summands  # Unused.
    
    indices = encoded_tensors[self.INDICES_KEY]
    values = encoded_tensors[self.ENCODED_VALUES_KEY]
    tensor = tf.fill([800], 0.0)
    decoded_values = tf.tensor_scatter_nd_update(tensor, indices, values)
    
    return tf.reshape(decoded_values, shape)



def sparse_quantizing_encoder():
  encoder = te.core.EncoderComposer(
      StatefulTopKEncodingStage() )  
  return encoder.make()

fedavg_with_sparsification.py

[...]

def sparsification_broadcast_encoder_fn(value):
  spec = tf.TensorSpec(value.shape, value.dtype)
  return te.encoders.as_simple_encoder(te.encoders.identity(), spec)

def sparsification_mean_encoder_fn(value):
  spec = tf.TensorSpec(value.shape, value.dtype)
  
  if value.shape.num_elements() == 800:
    return te.encoders.as_gather_encoder(
        stateful_encoding_stage_topk.sparse_quantizing_encoder(), spec)

  else:
    return te.encoders.as_gather_encoder(te.encoders.identity(), spec)
  
encoded_broadcast_process = (
    tff.learning.framework.build_encoded_broadcast_process_from_model(
        model_fn, sparsification_broadcast_encoder_fn))

encoded_mean_process = (
    tff.learning.framework.build_encoded_mean_process_from_model(
        model_fn, sparsification_mean_encoder_fn))


iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.004),
    client_weight_fn=lambda _: tf.constant(1.0),
    broadcast_process=encoded_broadcast_process,
    aggregation_process=encoded_mean_process)

[...]

我正在使用:

我会尝试分两部分回答; (1) top_k 没有状态的编码器和 (2) 在 TFF 中实现你似乎想要的有状态的想法。

(1)

为了让 TopKEncodingStage 在没有状态的情况下工作,我看到了一些需要更改的细节。

commutes_with_sum 属性 应设置为 False。在伪代码中,它的意思是是否 sum_x(decode(encode(x))) == decode(sum_x(encode(x))) 。这对于您的 encode 方法 returns 的表示是不正确的——对 indices 求和效果不佳。我认为 decode 方法的实现可以简化为

return tf.scatter_nd(
    indices=encoded_tensors[self.INDICES_KEY],
    updates=encoded_tensors[self.ENCODED_VALUES_KEY],
    shape=shape)

(2)

使用 tff.learning.build_federated_averaging_process 无法以这种方式实现您所指的内容。此方法返回的进程没有任何维护 client/local 状态的机制。无论您 StatefulTopKEncodingStage 中表达的状态是什么,最终都会成为服务器状态,而不是本地状态。

要使用 client/local 状态,您可能需要编写更多自定义代码。对于入门者,请参阅 examples/stateful_clients,您可以对其进行调整以存储您引用的状态。

请记住,在 TFF 中,这需要表示为函数转换。将值存储在 class 的属性中并在其他地方使用它们可能会导致令人惊讶的错误。