Theano 共享更新 python 中的最后一个元素

Theano shared updating last element in python

我有一个共享变量 persistent_vis_chain,它由一个 theano 函数更新,它从 theano.scan 获取函数,但这不是背景故事的问题。

我的共享变量看起来像 D = [image1, ... , imageN] ,其中每个图像都是 [x1,x2,...,x784]。

我要做的是对所有图片取平均值,放到最后一张图片N中。那就是我想对每个图像中除最后一个 1 之外的所有值求和,这将导致 [s1,s2,...,s784] 然后我想设置 imageN = [s1/len(D),s2/len(D),...s784/len(D)]

所以我的问题是我不知道如何使用 theano.shared 执行此操作,并且可能与我对 theano 函数的理解以及使用符号变量进行此计算有关。任何帮助将不胜感激。

如果您有 N 个图像,每个图像的形状都是 28x28=784,那么您的共享变量的形状大概是 (N,28,28)(N,784)?此方法适用于任何一种形状。

假设 D 是包含图像数据的共享变量。如果你想得到平均图像那么D.mean(keepdims=True)会象征性地给你。

不清楚您是想将最终图像更改为等于平均图像(听起来很奇怪),还是想向共享变量添加更多 N+1 图像.对于前者,您可以这样做:

D = theano.shared(load_D_data())
D_update_expression = do_something_with_scan_to_get_D_update_expression(D)
updates = [(D, T.concatenate(D_update_expression[:-1],
                             D_update_expression.mean(keepdims=True)))]
f = theano.function(..., updates=updates)

如果您想执行后者(添加额外的图像),请按如下方式更改 updates 行:

updates = [(D, T.concatenate(D_update_expression,
                             D_update_expression.mean(keepdims=True)))]

请注意,此代码仅供参考。它可能无法正常工作(例如,您可能需要弄乱 T.concatenate 命令中的 axis= 参数)。

重点是您需要构造一个符号表达式来解释 D 的新值是什么样的。您希望它是扫描更新加上这个额外的平均事物的组合。 T.concatenate 允许您将这两个部分组合在一起。