无法腌制 keras.layers.StringLookup

unable to pickle keras.layers.StringLookup

要重现的最少代码:

import tensorflow
import pickle
print(tensorflow.__version__)
pickle.dumps(tensorflow.keras.layers.StringLookup())

输出:

2.8.0
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-2-caa8edb2bfc5> in <module>()
      2 import pickle
      3 print(tensorflow.__version__)
----> 4 pickle.dumps(tensorflow.keras.layers.StringLookup())

1 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py in _numpy(self)
   1189       return self._numpy_internal()
   1190     except core._NotOkStatusException as e:  # pylint: disable=protected-access
-> 1191       raise core._status_to_exception(e) from None  # pylint: disable=protected-access
   1192 
   1193   @property

InvalidArgumentError: Cannot convert a Tensor of dtype resource to a NumPy array.

这是在 tensorflow 2.8 上,但我在其他版本上也遇到了同样的问题

几个地方提到的解决方法是分别保存权重和配置

使用这个包装器 class 可以在没有太多样板的情况下做到这一点(包装器是可腌制的,因此任何包含它的对象都可以无缝序列化)

class KerasPickleWrapper:
    def __init__(self, obj=None):
        self.obj = obj

    def __getstate__(self):
        if self.obj is not None:
            return self.obj.__class__, self.obj.get_config(), self.obj.get_weights()
        else:
            return None

    def __setstate__(self, state):
        if state is not None:
            cls, config, weights = state
            self.obj = cls.from_config(config)
            self.obj.set_weights(weights)