在集群上缓存中间值

cache intermediate values on cluster

我正在 运行ning dask distributed 并希望保存中间值。例如,当我 运行 这个:

from distributed import Client
client = Client(IP_PORT_TO_SCHEDULER)
from dask import delayed

@delayed(pure=True)
def myfunction(a):
    print("recomputing")
    return a + 3

res = myfunction(1)
res2 = res**2
res3 = client.persist(res2)

resagain = res**3
resagain2 = client.persist(resagain)

我希望 "recompute" 只打印一次。但是,在这种情况下,它会打印两次。我认为这可能是因为客户端没有缓存这个中间值。比如运行ning client.has_what(),我看到这个:

{'tcp://xx.xx.xx.xx:xxxx': ['pow-9d66a68ce8be79ff9cca17a2dc58aa0b',
'pow-440784f1abedb14511aa0d633935b55a']}

我看到了幂函数的最终结果,但没有看到中间计算。有没有办法强制客户端存储这个中间计算?谢谢!

Dask 将保留您明确保留的所有结果。将清除任何中间结果以节省内存。

因此,在您的情况下,您可能需要执行以下操作:

res = myfunction(1)
res = res.persist()  # Ask Dask to keep this in memory
...