具有任意嵌套 Vects 的张量定义

Tensor definition with arbitrarily nested Vects

我正在尝试创建一个 Tensor 类型,但我在使用构造函数的类型签名时遇到了问题。在 and question, they define a Tensor as a Vect of Tensors, and in this 问题中作为嵌套 Vect 的类型别名,但两者都不适合我的目的。我需要一个 Tensor 是原子的(它不是由其他 Tensor 组成)和一个不同的类型(它不会因为是别名而继承方法)。

我尝试了以下方法,它通过 array_type 从任意嵌套的 Vect 中隐式提取形状和数据类型,并将其包装在最小的 Tensor 类型

import Data.Vect

total array_type: (shape: Vect r Nat) -> (dtype: Type) -> Type
array_type [] dtype = dtype
array_type (d :: ds) dtype = Vect d (array_type ds dtype)

data Tensor : (shape: Vect r Nat) -> (dtype: Type) -> Type where
  MkTensor : array_type shape dtype -> Tensor shape dtype

然后我定义了各种函数来检查它是否正常工作(这里不包括)。所有这些编译都很好,但是当我尝试定义一个函数来将每个元素乘以二时,我陷入了真正的纠结。我试图首先在嵌套的 Vect:

上定义它
times_two : Num dtype => array_type shape dtype -> array_type shape dtype
times_two (x :: xs) = (times_two x) :: (times_two xs)
times_two x = 2 * x

但我明白了

When checking left hand side of times_two:
When checking an application of Main.times_two:
Can't disambiguate since no name has a suitable type:
Prelude.List.::, Prelude.Stream.::, Data.Vect.::

Data.Vect.:: 替换 :: 没有帮助。我正在尝试做的事情可能吗?明智?

您无法匹配 array_type shape dtype,因为它不是数据类型。在该类型简化为数据类型之前,您需要弄清楚(即匹配)shape 是什么。

times_two {shape = []} x = 2 * x
times_two {shape = n :: ns} xs = map times_two xs

(在这种情况下,xs 上的匹配项在 map 内。)