矢量化逻辑以检查索引是否匹配



我具有以下功能完美工作的功能,但是我想将矢量化应用于...

for i = 1:size(centroids,1)
    centroids(i, :) = mean(X(idx == i, :));
end

它检查idx是否匹配当前索引,如果确实如此,则计算与该索引相对应的所有X值的mean值。

这是我的矢量化尝试,我的解决方案不起作用,我知道为什么...

centroids = mean(X(idx == [1:size(centroids,1)], :));

以下idx == [1:size(centroids,1)]打破了代码。我不知道如何检查idx是否等于1size(centroids,1)的任何一个数字。

TL:DR

通过矢量化摆脱for循环

一个选项是使用 arrayfun;

nIdx      = size(centroids,1);
centroids = arrayfun(@(ii) mean(X(idx==ii,:)),1:nIdx, 'UniformOutput', false);
centroids = vertcat(centroids{:})

由于单个功能调用的输出不一定是标量,因此必须将UniformOutput选项设置为false。因此,arrayfun返回一个单元格数组,您需要vertcat以获取所需的双数组。

您可以使用cellfun将矩阵分为单元格,并从每个单元格中均值(在其内部操作中应用一个环(:

生成数据:

dim = 10;
N = 400;
nc = 20;
idx = randi(nc,[N 1]);
X = rand(N,dim);
centroids = zeros(nc,dim);

平均使用循环(问题的方法(

for i = 1:size(centroids,1)
    centroids(i, :) = mean(X(idx == i, :));
end

矢量化:

% split X into cells by idx
A = accumarray(idx, (1:N)', [nc,1], @(i) {X(i,:)});
% mean of each cell
C = cell2mat(cellfun(@(x) mean(x,1),A,'UniformOutput',0));

方法之间的最大绝对错误:

max(abs(C(:) - centroids(:))) % about 1e-16

最新更新