优化循环搜索最近点
Optimize for loops searching for closest points
我正在尝试优化耗时过长的 for 循环。我确信它可以被优化,但由于我是 R 的新手,我不确定该怎么做。
我有两个矩阵,tarU_x
和 src_x
。对于 tarU_x
中的每一行,我想在 src_x
中找到最接近的一行并分配相同的标签(我在 src_y
中有 src_x
的标签,以及估计的标签因为 tarU_x
将在 tarU_y
).
我正在使用经典的嵌套 for 循环来做这件事,它不是很有效,所以我想利用 R 提供的可能性。代码如下:
# Estimate tarU_y
tarU_y <- vector()
for (i in 1:nrow(tarU_x)) {
tarU_vector <- tarU_x[i,]
lowest_dist <- Inf
lowest_dist_class <- -1
for (j in 1:nrow(src_x)) {
distance <- dist(rbind(tarU_vector, src_x[j,]))
if (distance < lowest_dist) {
lowest_dist <- distance
lowest_dist_class <- src_y[j]
}
}
tarU_y[i] <- lowest_dist_class
}
编辑
我尝试使用 apply
,正如 s__ 所建议的那样,并使其正常工作,最终得到以下代码:
distances <- apply(src_x, 1, function (y) apply(tarU_x, 1, function(x) dist(rbind(x,y))))
tarU_y <- apply(distances, 1, function(x) src_y[which.min(x)])
不过好像慢了点,估计底层代码差不多。例如,使用 for 循环的测试耗时 14 秒,而使用 apply
耗时 16 秒。
有关更多信息,我使用的数据是此处提供的数据:https://archive.ics.uci.edu/ml/datasets/Gas+Sensor+Array+Drift+Dataset+at+Different+Concentrations,分为 10 个不同的批次,每个样本有 128 个特征。
尝试 library(pracma)
中的 distmat
功能:
library(pracma)
tarU_y <- src_y[max.col(-distmat(tarU_x, src_x))]
编辑:添加了基准
使用随机法线矩阵进行说明:
library(pracma)
library(microbenchmark)
set.seed(123)
tarU_x <- matrix(rnorm(1e4, mean = rep(1:100, 10)), nrow = 100L)
src_x <- matrix(rnorm(2e4, mean = rep(200:1, 10)/2), nrow = 200L)
src_y <- rep(200:1, 10)/2
using.forloop <- function(x1, x2, y1) {
y2 <- rep(y1[1], nrow(x2))
for (i in 1:nrow(x2)) {
lowest_dist <- dist(rbind(x1[1,], x2[i,]))
for (j in 2:nrow(x1)) {
distance <- dist(rbind(x1[j,], x2[i,]))
if (distance < lowest_dist) {
lowest_dist <- distance
y2[i] <- y1[j]
}
}
}
return(y2)
}
using.distmat <- function(x1, x2, y1) {
return(y1[max.col(-distmat(x2, x1))])
}
all.equal(using.forloop(src_x, tarU_x, src_y), using.distmat(src_x, tarU_x, src_y))
[1] TRUE
microbenchmark(using.forloop(src_x, tarU_x, src_y), using.distmat(src_x, tarU_x, src_y))
Unit: milliseconds
expr min lq mean median uq max neval
using.forloop(src_x, tarU_x, src_y) 415.8176 447.95200 473.345159 462.0715 495.33775 609.8592 100
using.distmat(src_x, tarU_x, src_y) 2.4413 2.59575 2.779786 2.7072 2.91965 3.8540 100
我正在尝试优化耗时过长的 for 循环。我确信它可以被优化,但由于我是 R 的新手,我不确定该怎么做。
我有两个矩阵,tarU_x
和 src_x
。对于 tarU_x
中的每一行,我想在 src_x
中找到最接近的一行并分配相同的标签(我在 src_y
中有 src_x
的标签,以及估计的标签因为 tarU_x
将在 tarU_y
).
我正在使用经典的嵌套 for 循环来做这件事,它不是很有效,所以我想利用 R 提供的可能性。代码如下:
# Estimate tarU_y
tarU_y <- vector()
for (i in 1:nrow(tarU_x)) {
tarU_vector <- tarU_x[i,]
lowest_dist <- Inf
lowest_dist_class <- -1
for (j in 1:nrow(src_x)) {
distance <- dist(rbind(tarU_vector, src_x[j,]))
if (distance < lowest_dist) {
lowest_dist <- distance
lowest_dist_class <- src_y[j]
}
}
tarU_y[i] <- lowest_dist_class
}
编辑
我尝试使用 apply
,正如 s__ 所建议的那样,并使其正常工作,最终得到以下代码:
distances <- apply(src_x, 1, function (y) apply(tarU_x, 1, function(x) dist(rbind(x,y))))
tarU_y <- apply(distances, 1, function(x) src_y[which.min(x)])
不过好像慢了点,估计底层代码差不多。例如,使用 for 循环的测试耗时 14 秒,而使用 apply
耗时 16 秒。
有关更多信息,我使用的数据是此处提供的数据:https://archive.ics.uci.edu/ml/datasets/Gas+Sensor+Array+Drift+Dataset+at+Different+Concentrations,分为 10 个不同的批次,每个样本有 128 个特征。
尝试 library(pracma)
中的 distmat
功能:
library(pracma)
tarU_y <- src_y[max.col(-distmat(tarU_x, src_x))]
编辑:添加了基准
使用随机法线矩阵进行说明:
library(pracma)
library(microbenchmark)
set.seed(123)
tarU_x <- matrix(rnorm(1e4, mean = rep(1:100, 10)), nrow = 100L)
src_x <- matrix(rnorm(2e4, mean = rep(200:1, 10)/2), nrow = 200L)
src_y <- rep(200:1, 10)/2
using.forloop <- function(x1, x2, y1) {
y2 <- rep(y1[1], nrow(x2))
for (i in 1:nrow(x2)) {
lowest_dist <- dist(rbind(x1[1,], x2[i,]))
for (j in 2:nrow(x1)) {
distance <- dist(rbind(x1[j,], x2[i,]))
if (distance < lowest_dist) {
lowest_dist <- distance
y2[i] <- y1[j]
}
}
}
return(y2)
}
using.distmat <- function(x1, x2, y1) {
return(y1[max.col(-distmat(x2, x1))])
}
all.equal(using.forloop(src_x, tarU_x, src_y), using.distmat(src_x, tarU_x, src_y))
[1] TRUE
microbenchmark(using.forloop(src_x, tarU_x, src_y), using.distmat(src_x, tarU_x, src_y))
Unit: milliseconds
expr min lq mean median uq max neval
using.forloop(src_x, tarU_x, src_y) 415.8176 447.95200 473.345159 462.0715 495.33775 609.8592 100
using.distmat(src_x, tarU_x, src_y) 2.4413 2.59575 2.779786 2.7072 2.91965 3.8540 100