比较 Rust 中 ndarray 数组的形状

Compare shapes of ndarray Arrays in Rust

我是 Rust 新手。

假设矩阵 a 的形状为 (n1, n2)b 的形状为 (m1, m2)c 的形状为 (k1, k2)。 我想检查 ab 可以相乘(作为矩阵)并且 a * b 的形状等于 c。换句话说,(n2 == m1) && (n1 == k1) && (m2 == k2).

use ndarray::Array2;

// a : Array2<i64>
// b : Array2<i64>
// c : Array2<i64>

.shape method returns 数组的形状为切片。 简洁的方法是什么?

.shape() 返回的数组是否保证长度为 2,还是我应该检查一下?如果有保证,有没有办法跳过 None 检查?

let n1 = a.shape().get(0);  // this is Optional<i64>

Array2 具体有 .ncols() and .nrows() methods。如果您只使用二维数组,那么这可能是最佳选择。它们 return 使用,因此不需要 None 检查。

use ndarray::prelude::*;

fn is_valid_matmul(a: &Array2<i64>, b: &Array2<i64>, c: &Array2<i64>) -> bool {
    //nrows() and ncols() are only valid for Array2, 
    //[arr.nrows(), arr.ncols()] = [arr.shape()[0], arr.shape()[1]]
    return a.ncols() == b.nrows() && b.ncols() == c.ncols() && a.nrows() == c.nrows();
}
fn main() {
    let a = Array2::<i64>::zeros((3, 5));
    let b = Array2::<i64>::zeros((5, 6));
    let c_valid = Array2::<i64>::zeros((3, 6));
    let c_invalid = Array2::<i64>::zeros((8, 6));

    println!("is_valid_matmul(&a, &b, &c_valid) = {}", is_valid_matmul(&a, &b, &c_valid));
    println!("is_valid_matmul(&a, &b, &c_invalid) = {}", is_valid_matmul(&a, &b, &c_invalid));
}
/*
output:
is_valid_matmul(&a, &b, &c_valid) = true
is_valid_matmul(&a, &b, &c_invalid) = false
*/