禁止 collections.defaultdict 中的键添加

Suppress key addition in collections.defaultdict

当在 defaultdict 对象中查询丢失的键时,该键自动添加 到字典中:

from collections import defaultdict

d = defaultdict(int)
res = d[5]

print(d)
# defaultdict(<class 'int'>, {5: 0})
# we want this dictionary to remain empty

但是,我们通常只想在显式或隐式分配键时才添加键:

d[8] = 1  # we want this key added
d[3] += 1 # we want this key added

一个用例是简单计数,以避免 collections.Counter 的更高开销,但通常也可能需要此功能。


反例[双关]

这是我想要的功能:

from collections import Counter
c = Counter()
res = c[5]  # 0
print(c)  # Counter()

c[8] = 1  # key added successfully
c[3] += 1 # key added successfully

但是 Counter 明显比 defaultdict(int) 慢。我发现性能下降通常比 defaultdict(int).

慢 2 倍

另外,显然Counter只能与defaultdict中的int参数相媲美,而defaultdict可以取listset


有没有办法有效地实现上述行为;例如,通过继承 defaultdict?


基准测试示例

%timeit DwD(lst)           # 72 ms
%timeit dd(lst)            # 44 ms
%timeit counter_func(lst)  # 98 ms
%timeit af(lst)            # 72 ms

测试代码:

import numpy as np
from collections import defaultdict, Counter, UserDict

class DefaultDict(defaultdict):
    def get_and_forget(self, key):
        _sentinel = object()
        value = self.get(key, _sentinel)

        if value is _sentinel:
            return self.default_factory()
        return value

class DictWithDefaults(dict):
    __slots__ = ['_factory']  # avoid using extra memory

    def __init__(self, factory, *args, **kwargs):
        self._factory = factory
        super().__init__(*args, **kwargs)

    def __missing__(self, key):
        return self._factory()

lst = np.random.randint(0, 10, 100000)

def DwD(lst):
    d = DictWithDefaults(int)
    for i in lst:
        d[i] += 1
    return d

def dd(lst):
    d = defaultdict(int)
    for i in lst:
        d[i] += 1
    return d

def counter_func(lst):
    d = Counter()
    for i in lst:
        d[i] += 1
    return d

def af(lst):
    d = DefaultDict(int)
    for i in lst:
        d[i] += 1
    return d

关于赏金评论的注意事项:

自提供赏金后已更新,因此请忽略赏金评论。

与其乱搞 collections.defaultdict 让它做我们想做的事,不如自己实现:

class DefaultDict(dict):
    def __init__(self, default_factory, **kwargs):
        super().__init__(**kwargs)

        self.default_factory = default_factory

    def __getitem__(self, key):
        try:
            return super().__getitem__(key)
        except KeyError:
            return self.default_factory()

这按你想要的方式工作:

d = DefaultDict(int)

res = d[5]
d[8] = 1 
d[3] += 1

print(d)  # {8: 1, 3: 1}

但是,对于可变类型,它可能会出现意外行为:

d = DefaultDict(list)
d[5].append('foobar')

print(d)  # output: {}

这可能就是为什么 defaultdict 在访问不存在的键时记住该值的原因。


另一种选择是扩展 defaultdict 并添加一个新方法来查找值而不记住它:

from collections import defaultdict

class DefaultDict(defaultdict):
    def get_and_forget(self, key):
        return self.get(key, self.default_factory())

请注意,get_and_forget 方法每次都会调用 default_factory(),无论该键是否已存在于字典中。如果这是不可取的,您可以使用标记值来实现它:

class DefaultDict(defaultdict):
    def get_and_forget(self, key):
        _sentinel = object()
        value = self.get(key, _sentinel)

        if value is _sentinel:
            return self.default_factory()
        return value

这对可变类型有更好的支持,因为它允许您选择是否应将值添加到字典中。

如果您只想在访问不存在的密钥时将 dict 作为默认值 return,那么您可以简单地子 class dict 并实现 __missing__:

object.__missing__(self, key)

Called by dict.__getitem__() to implement self[key] for dict subclasses when key is not in the dictionary.

看起来像这样:

class DictWithDefaults(dict):
    # not necessary, just a memory optimization
    __slots__ = ['_factory']  

    def __init__(self, factory, *args, **kwargs):
        self._factory = factory
        super().__init__(*args, **kwargs)

    def __missing__(self, key):
        return self._factory()

在这种情况下,我使用了类似 defaultdict 的方法,因此您必须传入一个 factory,它在调用时应提供默认值:

>>> dwd = DictWithDefaults(int)
>>> dwd[0]  # key does not exist
0 
>>> dwd     # key still doesn't exist
{}
>>> dwd[0] = 10
>>> dwd
{0: 10}

当您进行赋值(显式或隐式)时,值将被添加到字典中:

>>> dwd = DictWithDefaults(int)
>>> dwd[0] += 1
>>> dwd
{0: 1}

>>> dwd = DictWithDefaults(list)
>>> dwd[0] += [1]
>>> dwd
{0: [1]}

您想知道 collections.Counter 是如何做到的,从 CPython 3.6.5 开始,它还使用 __missing__:

class Counter(dict):
    ...

    def __missing__(self, key):
        'The count of elements not in the Counter is zero.'
        # Needed so that self[missing_item] does not raise KeyError
        return 0

    ...

更好的性能?!

你提到速度是一个问题,所以你可以使 class 成为 C 扩展 class(假设你使用 CPython),例如使用 Cython(我正在使用 Jupyter magic创建扩展的命令 class):

%load_ext cython

%%cython

cdef class DictWithDefaultsCython(dict):
    cdef object _factory

    def __init__(self, factory, *args, **kwargs):
        self._factory = factory
        super().__init__(*args, **kwargs)

    def __missing__(self, key):
        return self._factory()

基准

根据您的基准:

from collections import Counter, defaultdict

def d_py(lst):
    d = DictWithDefaults(int)
    for i in lst:
        d[i] += 1
    return d

def d_cy(lst):
    d = DictWithDefaultsCython(int)
    for i in lst:
        d[i] += 1
    return d

def d_dd(lst):
    d = defaultdict(int)
    for i in lst:
        d[i] += 1
    return d

鉴于这只是计算,如果不包含仅使用 Counter 初始值设定项的基准将是一个(不可原谅的)疏忽。

我最近写了一个小的基准测试工具,我认为它在这里可能会派上用场(但你也可以使用 %timeit 来实现):

from simple_benchmark import benchmark
import random

sizes = [2**i for i in range(2, 20)]
unique_lists = {i: list(range(i)) for i in sizes}
identical_lists = {i: [0]*i for i in sizes}
mixed = {i: [random.randint(0, i // 2) for j in range(i)]  for i in sizes}

functions = [d_py, d_cy, d_dd, d_c, Counter]

b_unique = benchmark(functions, unique_lists, 'list size')
b_identical = benchmark(functions, identical_lists, 'list size')
b_mixed = benchmark(functions, mixed, 'list size')

结果如下:

import matplotlib.pyplot as plt

f, (ax1, ax2, ax3) = plt.subplots(1, 3, sharey=True)
ax1.set_title('unique elements')
ax2.set_title('identical elements')
ax3.set_title('mixed elements')
b_unique.plot(ax=ax1)
b_identical.plot(ax=ax2)
b_mixed.plot(ax=ax3)

请注意,它使用对数刻度来更好地了解差异:

对于长迭代,Counter(iterable) 是迄今为止最快的。 DictWithDefaultCythondefaultdict 是相等的(大多数时候 DictWithDefault 稍微快一点,即使在这里看不到)然后是 DictWithDefault 然后是 Counter手动 for 循环。有趣的是 Counter 是最快的还是最慢的。

隐式添加 returned 值 if 是 modifie

我掩盖的事实是它与 defaultdict 有很大不同,因为需要可变类型的“只是 return 默认不保存它”:

>>> from collections import defaultdict
>>> dd = defaultdict(list)
>>> dd[0].append(10)
>>> dd
defaultdict(list, {0: [10]})

>>> dwd = DictWithDefaults(list)
>>> dwd[0].append(10)
>>> dwd
{}

这意味着当您希望修改后的值在字典中可见时,您实际上需要设置元素。

然而,这让我有些好奇,所以我想分享一种方法,你可以如何实现它(如果需要)。但这只是一个快速测试,仅适用于使用代理的 append 调用。请不要在生产代码中使用它(从我的角度来看,这只是具有娱乐价值):

from wrapt import ObjectProxy

class DictWithDefaultsFunky(dict):
    __slots__ = ['_factory']  # avoid using extra memory

    def __init__(self, factory, *args, **kwargs):
        self._factory = factory
        super().__init__(*args, **kwargs)

    def __missing__(self, key):
        ret = self._factory()
        dict_ = self

        class AppendTrigger(ObjectProxy):
            def append(self, val):
                self.__wrapped__.append(val)
                dict_[key] = ret

        return AppendTrigger(ret)

这是一个字典,return 是一个代理对象(而不是真正的默认对象),它重载了一个方法,如果调用该方法,会将 return 值添加到字典中。它“有效”:

>>> d = DictWithDefaultsFunky(list)
>>> a = d[10]
>>> d
[]

>>> a.append(1)
>>> d
{10: [1]}

但它确实有一些缺陷(可以解决,但这只是概念验证,所以我不会在这里尝试):

>>> d = DictWithDefaultsFunky(list)
>>> a = d[10]
>>> b = d[10]
>>> d
{}
>>> a.append(1)
>>> d
{10: [1]}
>>> b.append(10)
>>> d  # oups, that overwrote the previous stored value ...
{10: [10]}

如果你真的想要这样的东西,你可能需要实现一个 class 来真正跟踪值内的变化(而不仅仅是 append 调用)。

如果你想避免隐式赋值

如果您不喜欢 += 或类似操作将值添加到字典这一事实(与前面的示例相反,该示例甚至试图以非常隐含的方式添加值),那么您可能应该将其实现为方法而不是特殊方法。

例如:

class SpecialDict(dict):
    __slots__ = ['_factory']
    
    def __init__(self, factory, *args, **kwargs):
        self._factory = factory
        
    def get_or_default_from_factory(self, key):
        try:
            return self[key]
        except KeyError:
            return self._factory()
        
>>> sd = SpecialDict(int)
>>> sd.get_or_default_from_factory(0)  
0
>>> sd  
{}
>>> sd[0] = sd.get_or_default_from_factory(0) + 1
>>> sd  
{0: 1}

这类似于 Aran-Feys 回答的行为,但它使用 trycatch 方法代替 get 和哨兵。

你的悬赏信息说的是 Aran-Fey 的回答 "does not work with mutable types"。 (对于未来的读者,赏金消息是"The current answer is good, but it does not work with mutable types. If the existing answer can be adapted, or another option solution put forward, to suit this purpose, this would be ideal.")

事实是,它确实适用于可变类型:

>>> d = DefaultDict(list)
>>> d[0] += [1]
>>> d[0]
[1]
>>> d[1]
[]
>>> 1 in d
False

没有的是d[1].append(2):

>>> d[1].append(2)
>>> d[1]
[]

那是因为这不涉及对字典的存储操作。唯一涉及的字典操作是项目检索。

dict 对象在 d[1]d[1].append(2) 中看到的内容没有区别。字典不参与 append 操作。如果没有讨厌的、脆弱的堆栈检查或类似的东西,dict 就无法只为 d[1].append(2).

存储列表。

所以这是没有希望的。你应该怎么做?

好吧,一种选择是使用常规 collections.defaultdict,并且当您不想存储默认值时不使用 []。您可以使用 inget:

if key in d:
    value = d[key]
else:
    ...

value = d.get(key, sentinel)

或者,您可以在不需要时关闭默认工厂。当您有单独的 "build" 和 "read" 阶段时,这通常是合理的,并且您不希望在读取阶段使用默认工厂:

d = collections.defaultdict(list)
for thing in whatever:
    d[thing].append(other_thing)
# turn off default factory
d.default_factory = None
use(d)