Pytorch torch.cholesky 忽略异常

Pytorch torch.cholesky ignoring exception

对于我的批处理中的一些矩阵,由于矩阵是奇异的,我遇到了异常。

L = th.cholesky(Xt.bmm(X))

cholesky_cpu: For batch 51100: U(22,22) is zero, singular U

由于它们对我的用例来说很少,所以我想忽略异常并进一步处理它们。我会将结果计算设置为 nan 是否可能?

实际上,如果我 catch 异常并使用 continue 仍然没有完成其余批次的计算。

在 C++ 中使用 Pytorch libtorch 也会发生同样的情况。

执行 cholesky 分解时,PyTorch 依赖 LAPACK 获取 CPU 张量,依赖 MAGMA 获取 CUDA 张量。在 PyTorch code used to call LAPACK the batch is just iterated over, invoking LAPACK's zpotrs_ function on each matrix separately. In the PyTorch code used to call MAGMA the entire batch is processed using MAGMA's magma_dpotrs_batched 中,这可能比分别迭代每个矩阵更快。

AFAIK 无法指示 MAGMA 或 LAPACK 不引发异常(尽管公平地说,我不是这些软件包的专家)。由于 MAGMA 可能以某种方式利用批处理,我们可能不想只默认使用迭代方法,因为我们可能会因不执行批处理 cholesky 而失去性能。

一个可能的解决方案是首先尝试执行批处理 cholesky 分解,如果失败,那么我们可以对批处理中的每个元素执行 cholesky 分解,将失败的条目设置为 NaN。

def cholesky_no_except(x, upper=False, force_iterative=False):
    success = False
    if not force_iterative:
        try:
            results = torch.cholesky(x, upper=upper)
            success = True
        except RuntimeError:
            pass

    if not success:
        # fall back to operating on each element separately
        results_list = []
        x_batched = x.reshape(-1, x.shape[-2], x.shape[-1])
        for batch_idx in range(x_batched.shape[0]):
            try:
                result = torch.cholesky(x_batched[batch_idx, :, :], upper=upper)
            except RuntimeError:
                # may want to only accept certain RuntimeErrors add a check here if that's the case
                # on failure create a "nan" matrix
                result = float('nan') + torch.empty(x.shape[-2], x.shape[-1], device=x.device, dtype=x.dtype)
            results_list.append(result)
        results = torch.cat(results_list, dim=0).reshape(*x.shape)

    return results

如果您预计异常在 cholesky 分解期间很常见,您可能希望使用 force_iterative=True 跳过尝试使用批处理版本的初始调用,因为在那种情况下,此函数可能只是在浪费时间第一次尝试。

根据Pytorch Discuss forum无法捕获异常。

不幸的是,解决方案是实现我自己的 简单批处理 cholesky (th.cholesky(..., upper=False)),然后使用 th.isnan 处理 Nan 值。

import torch as th

# nograd cholesky
def cholesky(A):
    L = th.zeros_like(A)

    for i in range(A.shape[-1]):
        for j in range(i+1):
            s = 0.0
            for k in range(j):
                s = s + L[...,i,k] * L[...,j,k]

            L[...,i,j] = th.sqrt(A[...,i,i] - s) if (i == j) else \
                      (1.0 / L[...,j,j] * (A[...,i,j] - s))
    return L

我不知道这与发布的其他解决方案相比速度如何,但它可能更快。

首先使用torch.det判断你的batch中是否有奇异矩阵。然后屏蔽掉那些矩阵。

output = Xt.bmm(X)
dets = torch.det(output)

# if output is of shape (bs, x, y), dets will be of shape (bs)
bad_idxs = dets==0 #might want an allclose here

output[bad_idxs] = 1. # fill singular matrices with 1s

L = torch.cholesky(output)

在你可能需要处理你用 1 填充的奇异矩阵之后,但你有它们的索引值,所以很容易抓住它们或排除它们。