优秀的 Metrics(指标)包提供了一个计算平均精度(average precision)的函数:`apk`。
问题是,它基于 `for` 循环实现,速度很慢:
# Load the packages that provide apk() and benchmark(). library() stops with
# an error if a package is missing; require() would only return FALSE and let
# the script fail later with a less obvious "could not find function" error.
library(Metrics)
library(rbenchmark)

# Example inputs: the `actual` items and a ranked `predicted` vector.
actual <- 1:20000
predicted <- c(1:20, 200:600, 900:1522, 14000:32955)

# Time 10 runs of the loop-based apk() at k = 5000.
benchmark(
  apk(5000, actual, predicted),
  replications = 10,
  columns = c("test", "replications", "elapsed", "relative")
)
# test replications elapsed relative
# 1 apk(5000, actual, predicted) 10 53.68 1
我不知道如何将这个函数矢量化,但想知道在 R 中是否有更好的实现方法。
我不得不承认,这个实现确实看起来很糟糕……试试这个:
apk2 <- function (k, actual, predicted) {
predicted <- head(predicted, k)
is.new <- rep(FALSE, length(predicted))
is.new[match(unique(predicted), predicted)] <- TRUE
is.relevant <- predicted %in% actual & is.new
score <- sum(cumsum(is.relevant) * is.relevant / seq_along(predicted)) /
min(length(actual), k)
score
}
# Time the loop-based apk() against the vectorised apk2(), 10 runs each.
benchmark(
  apk(5000, actual, predicted),
  apk2(5000, actual, predicted),
  replications = 10,
  columns = c("test", "replications", "elapsed", "relative")
)
# test replications elapsed relative
# 1 apk(5000, actual, predicted) 10 62.194 2961.619
# 2 apk2(5000, actual, predicted) 10 0.021 1.000

# Both implementations should produce exactly the same score.
identical(
  apk(5000, actual, predicted),
  apk2(5000, actual, predicted)
)
# [1] TRUE
I happen to have average-precision code written with a `for` loop, and I think it is fast enough:
# Average precision of a scored, labelled set.
#
# prediction: a two-column matrix (or data frame). Column 1 is the true
#   label, where the value 1 marks a positive; column 2 is the prediction
#   score.
#
# Rows are ranked by decreasing score, and rows whose score is NA are
# dropped. Returns the mean, over positive rows, of the precision at that
# row's rank, or 0 when there are no positives.
ap <- function(prediction) {
  # order() returns indices into the original rows even when scores contain
  # NA. sort(index.return = TRUE) does not: its indices refer to the NA-free
  # vector. Indexing the label column directly (rather than taking whole rows
  # with prediction[idx, ]) also keeps a single-row input working.
  ranking <- order(prediction[, 2], decreasing = TRUE, na.last = NA)
  ranked_labels <- prediction[ranking, 1]
  hits <- which(ranked_labels == 1)
  if (length(hits) == 0) {
    return(0)
  }
  # The j-th positive sits at rank hits[j], so its precision is j / hits[j].
  mean(seq_along(hits) / hits)
}
# Average precision at N of a scored, labelled set.
#
# prediction: a two-column matrix (or data frame). Column 1 is the true
#   label, where the value 1 marks a positive; column 2 is the prediction
#   score.
# N: a single non-negative number; how many top-ranked rows to score.
#
# Rows are ranked by decreasing score, and rows whose score is NA are
# dropped. The precisions at the positive rows within the top N are summed
# and divided by min(N, total number of positives). Returns 0 when that
# divisor is 0.
ap_at_N <- function(prediction, N = 20) {
  stopifnot(is.numeric(N), length(N) == 1, !is.na(N), N >= 0)
  # order() returns indices into the original rows even when scores contain
  # NA. sort(index.return = TRUE) does not: its indices refer to the NA-free
  # vector. Indexing the label column directly (rather than taking whole rows
  # with prediction[idx, ]) also keeps a single-row input working.
  ranking <- order(prediction[, 2], decreasing = TRUE, na.last = NA)
  ranked_labels <- prediction[ranking, 1]
  # Normaliser: every positive in the ranking, capped at N.
  n_pos <- min(N, sum(ranked_labels == 1, na.rm = TRUE))
  if (n_pos == 0) {
    return(0)
  }
  # Positives within the top N; the j-th one has precision j / hits[j].
  hits <- which(head(ranked_labels, N) == 1)
  sum(seq_along(hits) / hits) / n_pos
}