Rust 中的快速惯用 Floyd-Warshall 算法
Fast idiomatic Floyd-Warshall algorithm in Rust
我正在尝试在 Rust 中实现 Floyd-Warshall 算法的相当快的版本。该算法在有向加权图中找到所有顶点之间的最短路径。
算法的主要部分可以这样写:
// dist[i][j] contains edge length between vertices [i] and [j]
// after the end of the execution it contains shortest path between [i] and [j]
fn floyd_warshall(dist: &mut [Vec<i32>]) {
let n = dist.len();
for i in 0..n {
for j in 0..n {
for k in 0..n {
dist[j][k] = min(dist[j][k], dist[j][i] + dist[i][k]);
}
}
}
}
此实现非常简短且易于理解,但比类似的 C++ 实现慢 1.5 倍。
据我了解,问题是在每次访问向量时,Rust 都会检查索引是否在向量的范围内,这会增加一些开销。
我用 get_unchecked* 函数重写了这个函数:
fn floyd_warshall_unsafe(dist: &mut [Vec<i32>]) {
let n = dist.len();
for i in 0..n {
for j in 0..n {
for k in 0..n {
unsafe {
*dist[j].get_unchecked_mut(k) = min(
*dist[j].get_unchecked(k),
dist[j].get_unchecked(i) + dist[i].get_unchecked(k),
)
}
}
}
}
}
它真的开始工作快了 1.5 倍 (full code of the test)。
我没想到边界检查会增加那么多开销:(
是否可以在没有不安全的情况下以惯用的方式重写此代码,使其运行速度与不安全版本一样快?例如。是否可以通过向代码中添加一些断言来向编译器“证明”不存在越界访问?
乍一看,人们会希望这就足够了:
fn floyd_warshall(dist: &mut [Vec<i32>]) {
let n = dist.len();
for i in 0..n {
assert!(i < dist.len());
for j in 0..n {
assert!(j < dist.len());
assert!(i < dist[j].len());
let v2 = dist[j][i];
for k in 0..n {
assert!(k < dist[i].len());
assert!(k < dist[j].len());
dist[j][k] = min(dist[j][k], v2 + dist[i][k]);
}
}
}
}
添加断言是让 Rust 优化器相信变量确实在边界内的已知技巧。但是,它在这里不起作用。我们需要做的是以某种方式让 Rust 编译器更清楚地知道这些循环在边界内,而无需求助于深奥的代码。
为了实现这一点,我按照 David Eisenstat 的建议移动到二维数组:
fn floyd_warshall<const N:usize>(mut dist: Box<[[i32; N]; N]>) -> Box<[[i32; N]; N]> {
for i in 0..N {
for j in 0..N {
for k in 0..N {
dist[j][k] = min(dist[j][k], dist[j][i] + dist[i][k]);
}
}
}
dist
}
这使用常量泛型(Rust 的一个相对较新的特性)来指定堆上给定二维数组的大小。就其本身而言,此更改在我的机器上效果很好(比 usafe 快 100ms,比 unsafe 慢约 20ms)。此外,如果您将 v2 计算移动到 k 循环之外,如下所示:
fn floyd_warshall<const N:usize>(mut dist: Box<[[i32; N]; N]>) -> Box<[[i32; N]; N]> {
for i in 0..N {
for j in 0..N {
let v2 = dist[j][i];
for k in 0..N {
dist[j][k] = min(dist[j][k], v2 + dist[i][k]);
}
}
}
dist
}
改进很大(在我的机器上从 ~300ms 到 ~100ms)。同样的优化适用于 floyd_warshall_unsafe
,在我的机器上平均达到 ~100ms。检查程序集时(在 floyd_warshall 上使用 #[inline(never)]
),两者似乎都没有进行边界检查,并且两者看起来都在某种程度上被矢量化了。虽然,我不是汇编专家。
因为这是一个非常热的循环(最多进行三个边界检查),所以性能受到如此大的影响我并不感到惊讶。不幸的是,在这种情况下索引的使用非常复杂,以至于无法通过 assert 技巧轻松解决问题。还有其他已知的情况,其中需要断言检查来提高性能,但编译器无法充分使用这些信息。 Here is one such example.
Here is the playground 我的改动。
经过一些实验,根据 , and comments in related issue 中建议的想法,我找到了解决方案,其中:
- 仍然使用相同的接口(例如
&mut [Vec<i32>]
作为参数)
- 不使用不安全
- 比不安全版本快 3-4 倍
- 很丑:(
代码如下所示:
fn floyd_warshall_fast(dist: &mut [Vec<i32>]) {
let n = dist.len();
for i in 0..n {
for j in 0..n {
if i == j {
continue;
}
let (dist_j, dist_i) = if j < i {
let (lo, hi) = dist.split_at_mut(i);
(&mut lo[j][..n], &mut hi[0][..n])
} else {
let (lo, hi) = dist.split_at_mut(j);
(&mut hi[0][..n], &mut lo[i][..n])
};
let dist_ji = dist_j[i];
for k in 0..n {
dist_j[k] = min(dist_j[k], dist_ji + dist_i[k]);
}
}
}
}
里面有几个想法:
- 我们计算
dist_ji
一次,因为它不会在最内层循环内发生变化,编译器不需要考虑它。
- 我们“证明”
dist[i]
和dist[j]
其实是两个不同的向量。这是通过这个丑陋的 split_at_mut
事情和 i == j
特例完成的(真的很想知道一个更简单的解决方案)。之后我们可以完全分开对待 dist[i]
和 dist[j]
,例如编译器可以向量化这个循环,因为它知道数据不重叠。
- 最后一个技巧是向编译器“证明”
dist[i]
和 dist[j]
都至少有 n
个元素。这是在计算 dist[i]
和 dist[j]
时由 [..n]
完成的(例如,我们使用 &mut lo[j][..n]
而不是 &mut lo[j]
)。在那之后,编译器明白内循环永远不会使用越界值,并删除检查。
有趣的是,只有当所有三个优化都使用时,它才能大大提高速度。如果我们只使用其中的任意两个,编译器将无法对其进行优化。
我正在尝试在 Rust 中实现 Floyd-Warshall 算法的相当快的版本。该算法在有向加权图中找到所有顶点之间的最短路径。
算法的主要部分可以这样写:
// dist[i][j] contains edge length between vertices [i] and [j]
// after the end of the execution it contains shortest path between [i] and [j]
fn floyd_warshall(dist: &mut [Vec<i32>]) {
let n = dist.len();
for i in 0..n {
for j in 0..n {
for k in 0..n {
dist[j][k] = min(dist[j][k], dist[j][i] + dist[i][k]);
}
}
}
}
此实现非常简短且易于理解,但比类似的 C++ 实现慢 1.5 倍。
据我了解,问题是在每次访问向量时,Rust 都会检查索引是否在向量的范围内,这会增加一些开销。
我用 get_unchecked* 函数重写了这个函数:
fn floyd_warshall_unsafe(dist: &mut [Vec<i32>]) {
let n = dist.len();
for i in 0..n {
for j in 0..n {
for k in 0..n {
unsafe {
*dist[j].get_unchecked_mut(k) = min(
*dist[j].get_unchecked(k),
dist[j].get_unchecked(i) + dist[i].get_unchecked(k),
)
}
}
}
}
}
它真的开始工作快了 1.5 倍 (full code of the test)。
我没想到边界检查会增加那么多开销:(
是否可以在没有不安全的情况下以惯用的方式重写此代码,使其运行速度与不安全版本一样快?例如。是否可以通过向代码中添加一些断言来向编译器“证明”不存在越界访问?
乍一看,人们会希望这就足够了:
fn floyd_warshall(dist: &mut [Vec<i32>]) {
let n = dist.len();
for i in 0..n {
assert!(i < dist.len());
for j in 0..n {
assert!(j < dist.len());
assert!(i < dist[j].len());
let v2 = dist[j][i];
for k in 0..n {
assert!(k < dist[i].len());
assert!(k < dist[j].len());
dist[j][k] = min(dist[j][k], v2 + dist[i][k]);
}
}
}
}
添加断言是让 Rust 优化器相信变量确实在边界内的已知技巧。但是,它在这里不起作用。我们需要做的是以某种方式让 Rust 编译器更清楚地知道这些循环在边界内,而无需求助于深奥的代码。
为了实现这一点,我按照 David Eisenstat 的建议移动到二维数组:
fn floyd_warshall<const N:usize>(mut dist: Box<[[i32; N]; N]>) -> Box<[[i32; N]; N]> {
for i in 0..N {
for j in 0..N {
for k in 0..N {
dist[j][k] = min(dist[j][k], dist[j][i] + dist[i][k]);
}
}
}
dist
}
这使用常量泛型(Rust 的一个相对较新的特性)来指定堆上给定二维数组的大小。就其本身而言,此更改在我的机器上效果很好(比 usafe 快 100ms,比 unsafe 慢约 20ms)。此外,如果您将 v2 计算移动到 k 循环之外,如下所示:
fn floyd_warshall<const N:usize>(mut dist: Box<[[i32; N]; N]>) -> Box<[[i32; N]; N]> {
for i in 0..N {
for j in 0..N {
let v2 = dist[j][i];
for k in 0..N {
dist[j][k] = min(dist[j][k], v2 + dist[i][k]);
}
}
}
dist
}
改进很大(在我的机器上从 ~300ms 到 ~100ms)。同样的优化适用于 floyd_warshall_unsafe
,在我的机器上平均达到 ~100ms。检查程序集时(在 floyd_warshall 上使用 #[inline(never)]
),两者似乎都没有进行边界检查,并且两者看起来都在某种程度上被矢量化了。虽然,我不是汇编专家。
因为这是一个非常热的循环(最多进行三个边界检查),所以性能受到如此大的影响我并不感到惊讶。不幸的是,在这种情况下索引的使用非常复杂,以至于无法通过 assert 技巧轻松解决问题。还有其他已知的情况,其中需要断言检查来提高性能,但编译器无法充分使用这些信息。 Here is one such example.
Here is the playground 我的改动。
经过一些实验,根据
- 仍然使用相同的接口(例如
&mut [Vec<i32>]
作为参数) - 不使用不安全
- 比不安全版本快 3-4 倍
- 很丑:(
代码如下所示:
fn floyd_warshall_fast(dist: &mut [Vec<i32>]) {
let n = dist.len();
for i in 0..n {
for j in 0..n {
if i == j {
continue;
}
let (dist_j, dist_i) = if j < i {
let (lo, hi) = dist.split_at_mut(i);
(&mut lo[j][..n], &mut hi[0][..n])
} else {
let (lo, hi) = dist.split_at_mut(j);
(&mut hi[0][..n], &mut lo[i][..n])
};
let dist_ji = dist_j[i];
for k in 0..n {
dist_j[k] = min(dist_j[k], dist_ji + dist_i[k]);
}
}
}
}
里面有几个想法:
- 我们计算
dist_ji
一次,因为它不会在最内层循环内发生变化,编译器不需要考虑它。 - 我们“证明”
dist[i]
和dist[j]
其实是两个不同的向量。这是通过这个丑陋的split_at_mut
事情和i == j
特例完成的(真的很想知道一个更简单的解决方案)。之后我们可以完全分开对待dist[i]
和dist[j]
,例如编译器可以向量化这个循环,因为它知道数据不重叠。 - 最后一个技巧是向编译器“证明”
dist[i]
和dist[j]
都至少有n
个元素。这是在计算dist[i]
和dist[j]
时由[..n]
完成的(例如,我们使用&mut lo[j][..n]
而不是&mut lo[j]
)。在那之后,编译器明白内循环永远不会使用越界值,并删除检查。
有趣的是,只有当所有三个优化都使用时,它才能大大提高速度。如果我们只使用其中的任意两个,编译器将无法对其进行优化。