如何帮助 tqdm 计算出自定义迭代器中的总数
How to help tqdm figure out the total in a custom iterator
我正在实现自己的迭代器。 tqdm 不显示进度条,因为它不知道列表中元素的总数。我不想使用 "total=",因为它看起来很难看。相反,我更愿意向我的迭代器添加一些东西,tqdm 可以使用它来计算总数。
class Batches:
def __init__(self, batches, target_input):
self.batches = batches
self.pos = 0
self.target_input = target_input
def __iter__(self):
return self
def __next__(self):
if self.pos < len(self.batches):
minibatch = self.batches[self.pos]
target = minibatch[:, :, self.target_input]
self.pos += 1
return minibatch, target
else:
raise StopIteration
def __len__(self):
return self.batches.len()
这可能吗?在上面的代码中添加什么...
像下面这样使用 tqdm..
for minibatch, target in tqdm(Batches(test, target_input)):
output = lstm(minibatch)
loss = criterion(output, target)
writer.add_scalar('loss', loss, tensorboard_step)
我知道已经有一段时间了,但我一直在寻找相同的答案,这里是解决方案。而不是像这样用 tqdm 包装你的迭代
for i in tqdm(my_iterable):
do_something()
改用 "with" 关闭,如:
with tqdm(total=len(my_iterable)) as progress_bar:
for i in my_iterable:
do_something()
progress_bar.update(1) # update progress
对于您的批次,您可以将总数设置为批次数,并更新为 1(如上所述)。或者,您可以将 total 设置为实际的项目总数,将 update 设置为当前处理的批次的大小。
原来的问题是这样的:
I don't want to use "total=" as it looks ugly. Rather I would prefer to add something to my iterator that tqdm can use to figure out the total.
但是, 明确声明要使用 total
:
with tqdm(total=len(my_iterable)) as progress_bar:
事实上,给定的例子比它需要的更复杂,因为原始问题没有要求复杂的柱形更新。因此,
for i in tqdm(my_iterable, total=my_total):
do_something()
实际上已经足够了(正如作者@emem 已经在评论中指出的那样)。
这个问题相对较老(写这篇文章时已有 4 年),但查看 tqdm 的代码,可以看出 already from the very beginning(写这篇文章时已有 8 年)的行为是如果未给出 total
,则默认为 total = len(iterable)
。
因此,问题的正确答案是实施__len__
。正如问题中所述,原始示例已经实施。因此,它应该已经可以正常工作了。
可以在下面找到测试行为的完整示例(请注意 __len__
方法上方的注释):
from time import sleep
from tqdm import tqdm
class Iter:
def __init__(self, n=10):
self.n = n
self.iter = iter(range(n))
def __iter__(self):
return self
def __next__(self):
return next(self.iter)
# commenting the next two lines disables showing the bar
# due to tqdm not knowing the total number of elements:
def __len__(self):
return self.n
it = Iter()
for i in tqdm(it):
sleep(0.2)
看看 tqdm 到底做了什么:
try:
total = len(iterable)
except (TypeError, AttributeError):
total = None
...并且由于我们不确切知道@Duane 将什么用作 batches
,我认为这基本上只是一个隐藏得很好的错字(self.batches.len()
),这会导致 AttributeError
被 tqdm 捕获。
如果 batches
只是一个序列类型,那么这可能是预期的定义:
def __len__(self):
return len(self.batches)
__next__
的定义(使用len(self.batches)
)也指向这个方向
我正在实现自己的迭代器。 tqdm 不显示进度条,因为它不知道列表中元素的总数。我不想使用 "total=",因为它看起来很难看。相反,我更愿意向我的迭代器添加一些东西,tqdm 可以使用它来计算总数。
class Batches:
def __init__(self, batches, target_input):
self.batches = batches
self.pos = 0
self.target_input = target_input
def __iter__(self):
return self
def __next__(self):
if self.pos < len(self.batches):
minibatch = self.batches[self.pos]
target = minibatch[:, :, self.target_input]
self.pos += 1
return minibatch, target
else:
raise StopIteration
def __len__(self):
return self.batches.len()
这可能吗?在上面的代码中添加什么...
像下面这样使用 tqdm..
for minibatch, target in tqdm(Batches(test, target_input)):
output = lstm(minibatch)
loss = criterion(output, target)
writer.add_scalar('loss', loss, tensorboard_step)
我知道已经有一段时间了,但我一直在寻找相同的答案,这里是解决方案。而不是像这样用 tqdm 包装你的迭代
for i in tqdm(my_iterable):
do_something()
改用 "with" 关闭,如:
with tqdm(total=len(my_iterable)) as progress_bar:
for i in my_iterable:
do_something()
progress_bar.update(1) # update progress
对于您的批次,您可以将总数设置为批次数,并更新为 1(如上所述)。或者,您可以将 total 设置为实际的项目总数,将 update 设置为当前处理的批次的大小。
原来的问题是这样的:
I don't want to use "total=" as it looks ugly. Rather I would prefer to add something to my iterator that tqdm can use to figure out the total.
但是,total
:
with tqdm(total=len(my_iterable)) as progress_bar:
事实上,给定的例子比它需要的更复杂,因为原始问题没有要求复杂的柱形更新。因此,
for i in tqdm(my_iterable, total=my_total):
do_something()
实际上已经足够了(正如作者@emem 已经在评论中指出的那样)。
这个问题相对较老(写这篇文章时已有 4 年),但查看 tqdm 的代码,可以看出 already from the very beginning(写这篇文章时已有 8 年)的行为是如果未给出 total
,则默认为 total = len(iterable)
。
因此,问题的正确答案是实施__len__
。正如问题中所述,原始示例已经实施。因此,它应该已经可以正常工作了。
可以在下面找到测试行为的完整示例(请注意 __len__
方法上方的注释):
from time import sleep
from tqdm import tqdm
class Iter:
def __init__(self, n=10):
self.n = n
self.iter = iter(range(n))
def __iter__(self):
return self
def __next__(self):
return next(self.iter)
# commenting the next two lines disables showing the bar
# due to tqdm not knowing the total number of elements:
def __len__(self):
return self.n
it = Iter()
for i in tqdm(it):
sleep(0.2)
看看 tqdm 到底做了什么:
try:
total = len(iterable)
except (TypeError, AttributeError):
total = None
...并且由于我们不确切知道@Duane 将什么用作 batches
,我认为这基本上只是一个隐藏得很好的错字(self.batches.len()
),这会导致 AttributeError
被 tqdm 捕获。
如果 batches
只是一个序列类型,那么这可能是预期的定义:
def __len__(self):
return len(self.batches)
__next__
的定义(使用len(self.batches)
)也指向这个方向