用于张量索引的 Idris 非平凡类型计算

Idris non-trivial type computation for tensor indexing

我一直在摆弄一个简单的张量库,我在其中定义了以下类型。

data Tensor : Vect n Nat -> Type -> Type where
  Scalar : a -> Tensor [] a
  Dimension : Vect n (Tensor d a) -> Tensor (n :: d) a

类型的向量参数描述了张量的"dimensions"或"shape"。我目前正在尝试定义一个函数来安全地索引到 Tensor。我曾计划使用 Fins 来执行此操作,但我 运行 遇到了问题。因为 Tensor 的顺序未知,所以我可能需要任意数量的索引,每个索引都需要不同的上限。这意味着 Vect 个索引是不够的,因为每个索引都有不同的类型。这促使我考虑使用元组(在 Idris 中称为 "pairs"?)。我编写了以下函数来计算必要的类型。

TensorIndex : Vect n Nat -> Type
TensorIndex []      = ()
TensorIndex (d::[]) = Fin d
TensorIndex (d::ds) = (Fin d, TensorIndex ds)

这个函数如我所料,从维度向量计算出合适的索引类型。

> TensorIndex [4,4,3] -- (Fin 4, Fin 4, Fin 3)
> TensorIndex [2] -- Fin 2
> TensorIndex [] -- ()

但是当我试图定义实际的 index 函数时...

index : {d : Vect n Nat} -> TensorIndex d -> Tensor d a -> a
index () (Scalar x) = x
index (a,as) (Dimension xs) = index as $ index a xs
index a (Dimension xs) with (index a xs) | Tensor x = x

...Idris 在第二种情况下引发了以下错误(奇怪的是,第一种情况似乎完全没问题)。

Type mismatch between
         (A, B) (Type of (a,as))
and
         TensorIndex (n :: d) (Expected type)

该错误似乎暗示,它没有将 TensorIndex 视为极其复杂的类型同义词并像我希望的那样对其进行评估,而是将其视为用 data 定义的宣言; "black-box type" 可以这么说。伊德里斯在哪里划清界限?有什么方法可以让我重写 TensorIndex 以便它按照我想要的方式工作吗?如果没有,你能想出其他方法来编写index函数吗?

如果您在 TensorIndex 中允许尾随 (),您的生活会变得更加轻松,从那时起您就可以做到

TensorIndex : Vect n Nat -> Type
TensorIndex []      = ()
TensorIndex (d::ds) = (Fin d, TensorIndex ds)

index : {ds : Vect n Nat} -> TensorIndex ds -> Tensor ds a -> a
index {ds = []} () (Scalar x) = x
index {ds = _ :: ds} (i, is) (Dimension xs) = index is (index i xs)

如果您想保留 TensorIndex 的定义,您需要为 ds = [_]ds = _::_::_ 设置单独的大小写以匹配 TensorIndex 的结构:

TensorIndex : Vect n Nat -> Type
TensorIndex []      = ()
TensorIndex (d::[]) = Fin d
TensorIndex (d::ds) = (Fin d, TensorIndex ds)

index : {ds : Vect n Nat} -> TensorIndex ds -> Tensor ds a -> a
index {ds = []} () (Scalar x) = x
index {ds = _ :: []} i (Dimension xs) with (index i xs) | (Scalar x) = x
index {ds = _ :: _ :: _} (i, is) (Dimension xs) = index is (index i xs)

这个有效而你的无效的原因是因为在这里,index 的每个案例恰好对应一个 TensorIndex 案例,因此 TensorIndex ds 可以减少。

如果您通过对维度列表的归纳来定义 Tensor,而将 Index 定义为数据类型,则您的定义会更清晰。

确实,目前您被迫对 Vect n Nat 类型的隐式参数进行模式匹配,以查看索引的形状。但是,如果索引直接定义为一段数据,那么它会 约束 它索引到的结构的形状,并且一切都会到位:正确的信息在正确的时间到达让类型检查器开心。

module Tensor

import Data.Fin
import Data.Vect

tensor : Vect n Nat -> Type -> Type
tensor []        a = a
tensor (m :: ms) a = Vect m (tensor ms a)

data Index : Vect n Nat -> Type where
  Here : Index []
  At   : Fin m -> Index ms -> Index (m :: ms)

index : Index ms -> tensor ms a -> a
index Here     a = a
index (At k i) v = index i $ index k v