Code Vectorization in custom layer

Rui Xiang on 11 Apr 2018
Commented: Rui Xiang on 16 Apr 2018
Hi, we are designing a custom layer in which we need to compute the backward derivative from a 4-D array.
Here is a simple implementation using for loops:
X = zeros(2,2,2,2);
X([1 5 7 10 12 14 16]) = rand(7,1);
kernelsize=5;
A=cell(2,1);
A{1}=rand(2,5);
A{2}=rand(2,5);
f=cell(2,1);
f{1}=rand(2,1);
f{2}=rand(2,1);
k = find(X);
[ii, jj, kk, ll] = ind2sub( size(X), k);
Z=zeros(size(X));
dLdW=zeros(2,5,2);
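for j=1:kernelsize
    % Weight each nonzero of X by the dot product of column j of A with f
    for i=1:length(k)
        Z(k(i))=X(k(i))*dot(A{jj(i)}(:,j),f{jj(i)});
    end
    sol=sum(Z,2);              % sum over dimension 2
    dLdW(:,j,:)=sum(sol,4);    % then over dimension 4
    Z=zeros(size(X));          % reset for the next column
end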
Is there a way to avoid the for loops? I want to train this on a GPU.

Accepted Answer

Joss Knight on 15 Apr 2018
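% Row n of Adotf holds dot(A{n}(:,j), f{n}) for every column j
Adotf = cellfun(@(aa,ff)ff.'*aa, A, f, 'UniformOutput', false);
Adotf = cat(1, Adotf{:});
% One row per nonzero of X, one column per kernel index
Z = X(k).*Adotf(jj,:);
% Build matching (row, column, page) subscripts for every entry of Z
j = repmat(1:kernelsize, numel(ii), 1);
ii = repmat(ii, 1, kernelsize);
kk = repmat(kk, 1, kernelsize);
% Summing over dimensions 2 and 4 of X becomes an accumulation
dLdW = accumarray([ii(:), j(:), kk(:)], Z(:), [size(X,1), kernelsize, size(X,3)]);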
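To check that the vectorized version matches the original loop, a small comparison script along these lines may help. This is only a sketch: the renamed variables (rows, cols, pages, dLdW_loop, dLdW_vec, maxErr) and the fixed seed are my own choices for illustration, and it assumes R2016b or later for implicit expansion.

```matlab
rng(0);                                   % arbitrary fixed seed for repeatability
X = zeros(2,2,2,2);
X([1 5 7 10 12 14 16]) = rand(7,1);
kernelsize = 5;
A = {rand(2,kernelsize); rand(2,kernelsize)};
f = {rand(2,1); rand(2,1)};
k = find(X);
[ii, jj, kk, ~] = ind2sub(size(X), k);

% Reference result: the double loop from the question
dLdW_loop = zeros(size(X,1), kernelsize, size(X,3));
for c = 1:kernelsize
    Z = zeros(size(X));
    for n = 1:numel(k)
        Z(k(n)) = X(k(n)) * dot(A{jj(n)}(:,c), f{jj(n)});
    end
    dLdW_loop(:,c,:) = sum(sum(Z,2),4);
end

% Vectorized result: cellfun + accumarray
Adotf = cellfun(@(aa,ff) ff.'*aa, A, f, 'UniformOutput', false);
Adotf = cat(1, Adotf{:});
Zv    = X(k) .* Adotf(jj,:);
cols  = repmat(1:kernelsize, numel(k), 1);
rows  = repmat(ii, 1, kernelsize);
pages = repmat(kk, 1, kernelsize);
dLdW_vec = accumarray([rows(:), cols(:), pages(:)], Zv(:), ...
                      [size(X,1), kernelsize, size(X,3)]);

% Largest absolute difference; should be at round-off level
maxErr = max(abs(dLdW_loop(:) - dLdW_vec(:)));
```

Renaming the column index away from j also avoids shadowing MATLAB's built-in imaginary unit, though that is a matter of taste.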
Are all the A matrices and f vectors the same size? If so, you shouldn't use a cell array; concatenate them along dimension 3 and use pagefun instead of cellfun (if you're using gpuArray).
% Assumes A and f hold gpuArray data of equal size
A = cat(3, A{:});             % one matrix per page
f = cat(2, f{:});             % one vector per column
f = shiftdim(f, -1);          % make each vector a row on its own page
Adotf = pagefun(@mtimes, f, A);
Adotf = permute(Adotf, [3 2 1]);   % pages become rows
Z = X(k).*Adotf(jj,:);
j = repmat(1:kernelsize, numel(ii), 1);
ii = repmat(ii, 1, kernelsize);
kk = repmat(kk, 1, kernelsize);
dLdW = accumarray([ii(:), j(:), kk(:)], Z(:), [size(X,1), kernelsize, size(X,3)]);
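If you want to try the paged product on the CPU first, pagemtimes (available from R2020b, and not part of the answer above) should give the same Adotf as the cellfun version. A minimal sketch, with A3, f3, ref and maxErr as names made up for this example:

```matlab
A = {rand(2,5); rand(2,5)};
f = {rand(2,1); rand(2,1)};

A3 = cat(3, A{:});                  % 2-by-5-by-2, one matrix per page
f3 = shiftdim(cat(2, f{:}), -1);    % 1-by-2-by-2, one row vector per page

% Page-wise (1-by-2) * (2-by-5) products, then move pages to rows
Adotf = permute(pagemtimes(f3, A3), [3 2 1]);   % 2-by-5

% Compare against the cellfun formulation
ref = cellfun(@(aa,ff) ff.'*aa, A, f, 'UniformOutput', false);
ref = cat(1, ref{:});
maxErr = max(abs(Adotf(:) - ref(:)));
```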
2 Comments
Rui Xiang on 16 Apr 2018
They are not the same size. That's actually the biggest difficulty for me.
Rui Xiang on 16 Apr 2018
Thanks very much for your help :)
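On the point about differing sizes: the cellfun version only needs each product f{n}.'*A{n} to come out as 1-by-kernelsize, so it should still work when the A matrices have different row counts, as long as every A{n} has kernelsize columns and f{n} has as many rows as A{n}. The sizes below (3 and 4 rows) are made up purely for illustration:

```matlab
kernelsize = 5;
A = {rand(3,kernelsize); rand(4,kernelsize)};   % hypothetical: different row counts
f = {rand(3,1); rand(4,1)};                     % each f{n} matches the rows of A{n}

% Each cell yields a 1-by-kernelsize row, so the rows still stack cleanly
Adotf = cellfun(@(aa,ff) ff.'*aa, A, f, 'UniformOutput', false);
Adotf = cat(1, Adotf{:});                       % 2-by-kernelsize
```

The rest of the accumarray code only ever touches Adotf, so it does not depend on the row counts of A. The pagefun variant, by contrast, does need equal sizes, because the matrices have to be stacked into pages.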