比较 Rust 中 ndarray 数组的形状
Compare shapes of ndarray Arrays in Rust
我是 Rust 新手。
假设矩阵 a
的形状为 (n1, n2)
,b
的形状为 (m1, m2)
,c
的形状为 (k1, k2)
。
我想检查 a
和 b
可以相乘(作为矩阵)并且 a * b
的形状等于 c
。换句话说,(n2 == m1) && (n1 == k1) && (m2 == k2)
.
use ndarray::Array2;
// a : Array2<i64>
// b : Array2<i64>
// c : Array2<i64>
.shape
method returns 数组的形状为切片。
简洁的方法是什么?
从 .shape()
返回的数组是否保证长度为 2,还是我应该检查一下?如果有保证,有没有办法跳过 None
检查?
let n1 = a.shape().get(0); // this is Optional<i64>
Array2 具体有 .ncols() and .nrows() methods。如果您只使用二维数组,那么这可能是最佳选择。它们 return 使用,因此不需要 None
检查。
use ndarray::prelude::*;
fn is_valid_matmul(a: &Array2<i64>, b: &Array2<i64>, c: &Array2<i64>) -> bool {
//nrows() and ncols() are only valid for Array2,
//[arr.nrows(), arr.ncols()] = [arr.shape()[0], arr.shape()[1]]
return a.ncols() == b.nrows() && b.ncols() == c.ncols() && a.nrows() == c.nrows();
}
fn main() {
let a = Array2::<i64>::zeros((3, 5));
let b = Array2::<i64>::zeros((5, 6));
let c_valid = Array2::<i64>::zeros((3, 6));
let c_invalid = Array2::<i64>::zeros((8, 6));
println!("is_valid_matmul(&a, &b, &c_valid) = {}", is_valid_matmul(&a, &b, &c_valid));
println!("is_valid_matmul(&a, &b, &c_invalid) = {}", is_valid_matmul(&a, &b, &c_invalid));
}
/*
output:
is_valid_matmul(&a, &b, &c_valid) = true
is_valid_matmul(&a, &b, &c_invalid) = false
*/
我是 Rust 新手。
假设矩阵 a
的形状为 (n1, n2)
,b
的形状为 (m1, m2)
,c
的形状为 (k1, k2)
。
我想检查 a
和 b
可以相乘(作为矩阵)并且 a * b
的形状等于 c
。换句话说,(n2 == m1) && (n1 == k1) && (m2 == k2)
.
use ndarray::Array2;
// a : Array2<i64>
// b : Array2<i64>
// c : Array2<i64>
.shape
method returns 数组的形状为切片。
简洁的方法是什么?
从 .shape()
返回的数组是否保证长度为 2,还是我应该检查一下?如果有保证,有没有办法跳过 None
检查?
let n1 = a.shape().get(0); // this is Optional<i64>
Array2 具体有 .ncols() and .nrows() methods。如果您只使用二维数组,那么这可能是最佳选择。它们 return 使用,因此不需要 None
检查。
use ndarray::prelude::*;
fn is_valid_matmul(a: &Array2<i64>, b: &Array2<i64>, c: &Array2<i64>) -> bool {
//nrows() and ncols() are only valid for Array2,
//[arr.nrows(), arr.ncols()] = [arr.shape()[0], arr.shape()[1]]
return a.ncols() == b.nrows() && b.ncols() == c.ncols() && a.nrows() == c.nrows();
}
fn main() {
let a = Array2::<i64>::zeros((3, 5));
let b = Array2::<i64>::zeros((5, 6));
let c_valid = Array2::<i64>::zeros((3, 6));
let c_invalid = Array2::<i64>::zeros((8, 6));
println!("is_valid_matmul(&a, &b, &c_valid) = {}", is_valid_matmul(&a, &b, &c_valid));
println!("is_valid_matmul(&a, &b, &c_invalid) = {}", is_valid_matmul(&a, &b, &c_invalid));
}
/*
output:
is_valid_matmul(&a, &b, &c_valid) = true
is_valid_matmul(&a, &b, &c_invalid) = false
*/