Optimising Nearest Neighbor Program
Samuel Leeney
on 30 Oct 2020
Answered: Bruno Luong
on 31 Oct 2020
Hi Guys,
I am trying to optimise this code so that it runs in under 10 seconds for N=20k.
It currently takes around 40 seconds to run.
I think that I need to vectorise some or all of the loops so that the calculations are done at the same time, but I cannot figure out how to do it.
Any help would be much appreciated, as we are learning from home with next to no support from the university.
Here is the code.
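% Assumes N and pos (a 3-by-N array of coordinates in [0,1]) are already defined.
% First pass: fill the N-by-N distance matrix s.
s = zeros(N,N);
for n = 1:3                 % coordinate index (x, y, z)
    for a = 1:N
        count = 0;
        for b = 1:N
            % off = +1 or -1 when one point is near 0 and the other near 1 in this
            % coordinate, so the difference is measured across the box boundary.
            off = 0;
            if (pos(n,a) <= 0.25 && pos(n,b) >= 0.75)
                off = 1;
            elseif pos(n,a) >= 0.75 && pos(n,b) <= 0.25
                off = -1;
            end
            s(a,b) = s(a,b) + (pos(n,a) - pos(n,b) + off)^2.0;
            if n == 3
                s(a,b) = sqrt(s(a,b));    % take the root after the last coordinate
            end
        end
    end
end
% Second pass: match(a) is the index of the point closest to a (excluding a itself).
match = zeros(1,N);
for a = 1:N
    mindist = 1e10;
    for c = 1:N
        if (a ~= c)
            mindist = min(s(a,c), mindist);
            if (mindist == s(a,c))
                match(a) = c;
            end
        end
    end
end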
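For reference, the distance rule that the first pass implements can be written for a single pair of points, with the if/elseif offset replaced by logical arithmetic on all three coordinates at once. A minimal sketch; offsetDist is a hypothetical name, and p and q are assumed to be 3-by-1 coordinate vectors with entries in [0,1]:

```matlab
% offsetDist is a hypothetical helper, not part of the original code.
% The offset term is +1, -1 or 0 per coordinate, exactly as in the if/elseif above.
offsetDist = @(p, q) sqrt(sum((p - q + ...
    ((p <= 0.25 & q >= 0.75) - (p >= 0.75 & q <= 0.25))).^2));

p = [0.10; 0.50; 0.50];
q = [0.90; 0.50; 0.50];
d = offsetDist(p, q);   % x term is 0.10 - 0.90 + 1, so d is (up to rounding) 0.2
```

This form also makes it easy to see that the rule is symmetric: swapping p and q flips the sign of both the difference and the offset, so offsetDist(q, p) equals offsetDist(p, q).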
KSSV
on 30 Oct 2020
Refer to this; it might help you: https://in.mathworks.com/matlabcentral/fileexchange/44334-nearest-neighboring-particle-search-using-all-particles-search
Accepted Answer
Bruno Luong
on 31 Oct 2020
I'm posting my comments as an answer here so you can accept it if it helps.
N=20000;
nd=3;
pos=rand(nd,N);
pos_r = reshape(pos,[nd 1 N]);
s = zeros(N);
for b=1:N
posb = pos_r(:,:,b);
off = (pos<=0.25 & posb>=0.75) - ...
(pos>=0.75 & posb<=0.25);
sb = sum((pos-posb+off).^2,1);
s(:,b) = sb(:);
end
s = sqrt(s);
s(1:N+1:end) = Inf;
[~,match] = min(s,[],2);
match = match.'; % row vector
Still, if one doesn't have to deal with the odd offset, the Delaunay approach is much faster.
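For N = 20000 the full double matrix s needs 20000^2 * 8 bytes, about 3.2 GB. If that is a problem, the same column-by-column loop can take each column's minimum as it goes instead of storing s. A minimal sketch, assuming the same offset rule as the answer above: because that offset distance is symmetric (s(a,b) equals s(b,a)), the minimum over column b is also the nearest neighbour of point b. Squared distances are enough for the comparison, so sqrt is skipped. Implicit expansion (MATLAB R2016b or later) is assumed, as in the answer.

```matlab
N  = 2000;              % use 20000 for the size in the question
nd = 3;
pos = rand(nd,N);

match = zeros(1,N);
for b = 1:N
    posb = pos(:,b);                          % nd-by-1, expands against pos
    off  = (pos <= 0.25 & posb >= 0.75) - ...
           (pos >= 0.75 & posb <= 0.25);      % nd-by-N offsets for all pairs (a,b)
    d2   = sum((pos - posb + off).^2, 1);     % 1-by-N squared distances to point b
    d2(b) = Inf;                              % exclude point b itself
    [~, match(b)] = min(d2);                  % index of the nearest other point
end
```

Memory use is now O(N) per iteration instead of O(N^2) overall, while the arithmetic is the same as in the accepted answer.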
More Answers (3)
KSSV
on 30 Oct 2020
Edited: KSSV
on 30 Oct 2020
N = 1000;
seed = 1234;
%seed=input('Random number seed = ');
rng(seed)
pos=rand(3,N);
a = 1:N ;
b = 1:N ;
[a,b] = meshgrid(a,b) ;
S = zeros(N,N) ;
for n = 1:3
    off = zeros(N,N) ;
    P = pos(n,:) ;
    % Element-wise & is required here: && only accepts scalar operands
    off(P(a)<=0.25 & P(b) >= 0.75) = +1 ;
    off(P(a)>=0.75 & P(b)<=0.25) = -1 ;
    S = S+(P(a)-P(b)+off).^2;
end
S = sqrt(S);
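The meshgrid index matrices are not strictly required. With implicit expansion (MATLAB R2016b or later), comparing or subtracting an N-by-1 column and a 1-by-N row already yields an N-by-N result. A minimal sketch under that assumption, using the same seed and size as above; Pa and Pb are illustrative names, not from the thread:

```matlab
N = 1000;
rng(1234)
pos = rand(3,N);

S = zeros(N,N);
for n = 1:3
    Pa = pos(n,:).';   % N-by-1: point a runs down the rows
    Pb = pos(n,:);     % 1-by-N: point b runs across the columns
    % Offset per pair (a,b): +1, -1 or 0, as in the original if/elseif
    off = (Pa <= 0.25 & Pb >= 0.75) - (Pa >= 0.75 & Pb <= 0.25);
    S = S + (Pa - Pb + off).^2;   % implicit expansion gives an N-by-N matrix
end
S = sqrt(S);
```

Besides being shorter, this avoids allocating the two N-by-N index matrices a and b.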
KSSV
on 30 Oct 2020
How did you measure the time? Check the following:
N = 1000;
seed = 1234;
%seed=input('Random number seed = ');
rng(seed)
pos=rand(3,N);
t1 = tic ;
s = zeros(N,N);
for n = 1:3
for a = 1:N
count = 0;
for b=1:N
off=0;
if (pos(n,a)<=0.25 && pos(n,b) >= 0.75)
off=1;
elseif pos(n,a)>=0.75 && pos(n,b)<=0.25
off=-1;
end
s(a,b)=s(a,b)+(pos(n,a)-pos(n,b)+off)^2.0;
if n == 3
s(a,b)=sqrt(s(a,b));
end
end
end
end
t1 = toc(t1) ;
t2 = tic ;
a = 1:N ;
b = 1:N ;
[a,b] = meshgrid(a,b) ;
S = zeros(N,N) ;
for n = 1:3
    off = zeros(N,N) ;
    P = pos(n,:) ;
    % Element-wise & is required here: && only accepts scalar operands
    off(P(a)<=0.25 & P(b) >= 0.75) = +1 ;
    off(P(a)>=0.75 & P(b)<=0.25) = -1 ;
    S = S+(P(a)-P(b)+off).^2;
end
S = sqrt(S);
t2 = toc(t2) ;
Now compare t1 and t2 for different input sizes. My bet is that the second code will always be faster. I am comparing only the first part of the code (building the distance matrix).
Bruno Luong
on 30 Oct 2020
Edited: Bruno Luong
on 30 Oct 2020
I removed your handling of the offset (I'm not sure what its purpose is), and here is a much faster method using Delaunay triangulation:
clear
N=20000;
nd=3;
pos=rand(nd,N);
tic
s = zeros(N,N);
for n = 1:3
for a = 1:N
count = 0;
for b=1:N
s(a,b)=s(a,b)+(pos(n,a)-pos(n,b))^2.0;
if n == 3
s(a,b)=sqrt(s(a,b));
end
end
end
end
match = zeros(1,N);
for a=1:N
mindist=1e10;
for c=1:N
if (a~=c)
mindist=min(s(a,c),mindist);
if (mindist==s(a,c))
match(a)=c;
end
end
end
end
toc % Elapsed time is 56.414344 seconds.
% Find the nearest neighbour of each point within the same point set, in 2D or 3D
% INPUT: pos is array of size (nd x N), coordinates of N points in R^nd
tic
T = delaunay(pos.');
p = nchoosek(1:size(T,2),2);
P = T(:,p);
P = reshape(P,[],2);
P = unique(sort(P,2),'rows');
P1 = P(:,1);
P2 = P(:,2);
d2 = sum((pos(:,P2)-pos(:,P1)).^2,1);
A = [P1(:), d2(:), P2(:);
P2(:), d2(:), P1(:)];
A = sortrows(A,[1 2]);
b = [true; diff(A(:,1),1)>0];
A = A(b,:);
nn = A(:,3).'; % index of nearest neighbour
d = sqrt(A(:,2)).'; % corresponding distance
toc % Elapsed time is 0.349317 seconds.
isequal(match,nn) % 1
It won't help you with your exercise, but I still post it here for future readers looking for a fast method.
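The least obvious lines in the Delaunay code are the ones that turn the simplices in T into a list of unique edges (nchoosek, T(:,p), reshape, unique). A small worked example, using two hypothetical tetrahedra that share the face (20, 30, 40); the vertex indices are chosen for readability, not produced by delaunay:

```matlab
% Two hypothetical tetrahedra, one per row, sharing the face (20, 30, 40).
T = [10 20 30 40;
     20 30 40 50];
p = nchoosek(1:size(T,2), 2);   % 6-by-2: all pairs of vertex positions in a tetrahedron
P = T(:,p);                     % 2-by-12: first vertices of the 6 pairs, then second vertices
P = reshape(P, [], 2);          % 12-by-2: one edge per row
P = unique(sort(P,2), 'rows');  % orient each edge (small, large) and drop shared duplicates
disp(P)
```

The three edges of the shared face appear once per tetrahedron before unique, so the 12 rows shrink to 9. Since the nearest neighbour of a point is always joined to it by a Delaunay edge, only these edges need to be measured.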
Bruno Luong
on 30 Oct 2020
Edited: Bruno Luong
on 30 Oct 2020
I ran it 10 times, and isequal(match,nn) returned TRUE every time, so the two answers matched in all 10 runs with random points.
> for k=1:10; benchnntest; end
Elapsed time is 49.127755 seconds.
Elapsed time is 0.516644 seconds.
ans =
logical
1
Elapsed time is 51.189459 seconds.
Elapsed time is 0.459055 seconds.
ans =
logical
1
Elapsed time is 50.431960 seconds.
Elapsed time is 0.465887 seconds.
ans =
logical
1
Elapsed time is 50.426246 seconds.
Elapsed time is 0.454885 seconds.
ans =
logical
1
Elapsed time is 50.651649 seconds.
Elapsed time is 0.567084 seconds.
ans =
logical
1
Elapsed time is 50.889422 seconds.
Elapsed time is 0.461514 seconds.
ans =
logical
1
Elapsed time is 50.678441 seconds.
Elapsed time is 0.491820 seconds.
ans =
logical
1
Elapsed time is 50.476219 seconds.
Elapsed time is 0.451430 seconds.
ans =
logical
1
Elapsed time is 52.659327 seconds.
Elapsed time is 0.443114 seconds.
ans =
logical
1
Elapsed time is 52.004992 seconds.
Elapsed time is 0.459873 seconds.
ans =
logical
1
>>
Image Analyst
on 30 Oct 2020
MATLAB stores arrays in column-major order: elements that differ only in the leftmost index are adjacent in memory, so loops that vary the leftmost index fastest access memory sequentially. MATLAB goes down the rows of one column, then moves over to the next column and goes down its rows. So this slow code
for row = 1 : rows
for col = 1 : columns
s(row, col) = whatever; % Col iterates fastest
end
end
may be slower than this faster code
for col = 1 : columns
for row = 1 : rows
s(row, col) = whatever; % row iterates fastest
end
end
Note that in your code n is the leftmost index of pos, yet the n loop is the outermost loop, which is the slowest possible ordering. Likewise, s(a,b) is filled with b in the innermost loop, so the right-hand index varies fastest. If possible, move n to an inner loop. I've had luck in the past speeding up nested loops by doing that.
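Applied to the question's first pass, that advice could look like the sketch below. It is an illustration, not code from the thread: b (the column of s) is the outermost loop, a (the row of s) is next, and n (the row of pos) is innermost, so the leftmost index of each array changes fastest. Accumulating into the scalar acc (an illustrative name) also removes the if n == 3 test. N is kept small here, and any speed-up will depend on the MATLAB version and machine.

```matlab
N = 500;                 % small N for illustration
rng(1234)
pos = rand(3,N);

s = zeros(N,N);
for b = 1:N              % column of s: outermost
    for a = 1:N          % row of s: varies faster than b
        acc = 0;
        for n = 1:3      % row of pos: innermost
            off = 0;
            if pos(n,a) <= 0.25 && pos(n,b) >= 0.75
                off = 1;
            elseif pos(n,a) >= 0.75 && pos(n,b) <= 0.25
                off = -1;
            end
            acc = acc + (pos(n,a) - pos(n,b) + off)^2;
        end
        s(a,b) = sqrt(acc);   % each element of s is written once
    end
end
```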
Image Analyst
on 30 Oct 2020
Samuel, going by your description, I'd try something like this to find the closest point to each point.
N = 1000;
xyz = rand(N, 3); % Get N randomly located points in 3-D.
for k = 1 : N
    % Get the squared distance of point k to every point in the array.
    distancesSquared = (xyz(k, 1) - xyz(:, 1)) .^2 + ...
        (xyz(k, 2) - xyz(:, 2)) .^2 + ...
        (xyz(k, 3) - xyz(:, 3)) .^2;
    % We don't want to consider the distance of the point to itself, so set its own entry to infinity.
    distancesSquared(k) = inf;
    % Find the min value and the index of that min for the other points.
    [minDist2, index] = min(distancesSquared);
    % Print it out.
    fprintf('Point %d at (%.2f, %.2f, %.2f) is closest to point %d at (%.2f, %.2f, %.2f).\n',...
        k, xyz(k, 1), xyz(k, 2), xyz(k, 3), index, xyz(index, 1), xyz(index, 2), xyz(index, 3));
end
It prints output like:
Point 1 at (0.70, 0.37, 0.03) is closest to point 312 at (0.70, 0.37, 0.04).
Point 2 at (0.09, 0.71, 0.91) is closest to point 918 at (0.10, 0.72, 0.94).
Point 3 at (0.53, 0.95, 0.47) is closest to point 50 at (0.54, 0.93, 0.43).
etc.
Point 998 at (0.03, 0.67, 0.34) is closest to point 972 at (0.05, 0.59, 0.39).
Point 999 at (0.99, 0.87, 0.16) is closest to point 592 at (0.94, 0.93, 0.16).
Point 1000 at (0.54, 0.34, 0.42) is closest to point 540 at (0.55, 0.33, 0.42).
It doesn't handle the offset at the edges of the unit cube, though.