如何帮助 tqdm 计算出自定义迭代器中的总数

How to help tqdm figure out the total in a custom iterator

我正在实现自己的迭代器。 tqdm 不显示进度条,因为它不知道列表中元素的总数。我不想使用 "total=",因为它看起来很难看。相反,我更愿意向我的迭代器添加一些东西,tqdm 可以使用它来计算总数。

class Batches:
    def __init__(self, batches, target_input):
        self.batches = batches
        self.pos = 0
        self.target_input = target_input

    def __iter__(self):
        return self

    def __next__(self):
        if self.pos < len(self.batches):
            minibatch = self.batches[self.pos]
            target = minibatch[:, :, self.target_input]
            self.pos += 1
            return minibatch, target
        else:
            raise StopIteration

    def __len__(self):
        return self.batches.len()

这可能吗?在上面的代码中添加什么...

像下面这样使用 tqdm..

for minibatch, target in tqdm(Batches(test, target_input)):

    output = lstm(minibatch)
    loss = criterion(output, target)
    writer.add_scalar('loss', loss, tensorboard_step)

我知道已经有一段时间了,但我一直在寻找相同的答案,这里是解决方案。而不是像这样用 tqdm 包装你的迭代

for i in tqdm(my_iterable):
    do_something()

改用 "with" 关闭,如:

with tqdm(total=len(my_iterable)) as progress_bar:
    for i in my_iterable:
        do_something()
        progress_bar.update(1) # update progress

对于您的批次,您可以将总数设置为批次数,并更新为 1(如上所述)。或者,您可以将 total 设置为实际的项目总数,将 update 设置为当前处理的批次的大小。

原来的问题是这样的:

I don't want to use "total=" as it looks ugly. Rather I would prefer to add something to my iterator that tqdm can use to figure out the total.

但是, 明确声明要使用 total:

with tqdm(total=len(my_iterable)) as progress_bar:

事实上,给定的例子比它需要的更复杂,因为原始问题没有要求复杂的柱形更新。因此,

for i in tqdm(my_iterable, total=my_total):
    do_something()

实际上已经足够了(正如作者@emem 已经在评论中指出的那样)。


这个问题相对较老(写这篇文章时已有 4 年),但查看 tqdm 的代码,可以看出 already from the very beginning(写这篇文章时已有 8 年)的行为是如果未给出 total,则默认为 total = len(iterable)

因此,问题的正确答案是实施__len__。正如问题中所述,原始示例已经实施。因此,它应该已经可以正常工作了。

可以在下面找到测试行为的完整示例(请注意 __len__ 方法上方的注释):

from time import sleep
from tqdm import tqdm


class Iter:

    def __init__(self, n=10):
        self.n = n
        self.iter = iter(range(n))

    def __iter__(self):
        return self

    def __next__(self):
        return next(self.iter)

    # commenting the next two lines disables showing the bar
    # due to tqdm not knowing the total number of elements:
    def __len__(self):
        return self.n


it = Iter()
for i in tqdm(it):
    sleep(0.2)

看看 tqdm 到底做了什么:

try:
    total = len(iterable)
except (TypeError, AttributeError):
    total = None

...并且由于我们不确切知道@Duane 将什么用作 batches,我认为这基本上只是一个隐藏得很好的错字(self.batches.len()),这会导致 AttributeError 被 tqdm 捕获。

如果 batches 只是一个序列类型,那么这可能是预期的定义:

    def __len__(self):
        return len(self.batches)

__next__的定义(使用len(self.batches))也指向这个方向