implementation of mini-batch stochastic gradient descent

konoha on 28 Mar 2021
Answered: Mohamed Salem on 22 Dec 2022
I implemented mini-batch stochastic gradient descent but couldn't find the bug in my code.
I used this implementation for a classification problem, but all my final predictions are 0.
% Weights and biases of a 2-5-5-5-1 network, drawn uniformly from [-1, 1]
W2 = -1+2*rand(5,2); W3 = -1+2*rand(5,5);
W4 = -1+2*rand(5,5); W5 = -1+2*rand(1,5);
b2 = -1+2*rand(5,1); b3 = -1+2*rand(5,1);
b4 = -1+2*rand(5,1); b5 = -1+2*rand(1,1);
eta = 5e-3;          % learning rate
iter = 1000;         % number of epochs (full passes over the training set)
batch_size = 50;     % num_data must be a multiple of batch_size for the reshape below
num_data = length(label);   % label must be a 1 x num_data row vector
X = [x1;x2];         % 2 x num_data training inputs
loss_vec = zeros(1,iter);
tloss_vec = zeros(1,iter);
for it = 1:iter
    % mini-batch method: shuffle, then put batch_size indices in each column
    rand_idx = randperm(num_data);
    rand_idx = reshape(rand_idx,[],num_data/batch_size);
    for idx = rand_idx          % idx is one column, i.e. one mini-batch
        % forward pass
        a2 = activate(X(:,idx), W2, b2);
        a3 = activate(a2,W3,b3);
        a4 = activate(a3,W4,b4);
        a5 = activate(a4,W5,b5);
        % backward pass: deltas for the loss 0.5*||a5 - y||^2 with sigmoid units
        delta5 = a5.*(1-a5).*(a5-label(idx));
        delta4 = a4.*(1-a4).*(W5'*delta5);
        delta3 = a3.*(1-a3).*(W4'*delta4);
        delta2 = a2.*(1-a2).*(W3'*delta3);
        % update weights and biases with the batch-averaged gradient
        W2 = W2 - 1/length(idx)*eta*delta2*X(:,idx)';
        W3 = W3 - 1/length(idx)*eta*delta3*a2';
        W4 = W4 - 1/length(idx)*eta*delta4*a3';
        W5 = W5 - 1/length(idx)*eta*delta5*a4';
        b2 = b2 - 1/length(idx)*eta*sum(delta2,2);
        b3 = b3 - 1/length(idx)*eta*sum(delta3,2);
        b4 = b4 - 1/length(idx)*eta*sum(delta4,2);
        b5 = b5 - 1/length(idx)*eta*sum(delta5,2);
    end
    % train and test loss, computed once per epoch (tlabel has 200 entries in this post)
    loss_vec(it) = 1/(2*num_data)*LossFunc(W2,W3,W4,W5,b2,b3,b4,b5,X,label);
    tloss_vec(it) = 1/(2*numel(tlabel))*LossFunc(W2,W3,W4,W5,b2,b3,b4,b5,[tx1;tx2],tlabel);
end
%% cost function: squared L2 error summed over all samples
function loss = LossFunc(W2,W3,W4,W5,b2,b3,b4,b5,x,y)
    a2 = activate(x, W2, b2);
    a3 = activate(a2, W3, b3);
    a4 = activate(a3, W4, b4);
    a5 = activate(a4, W5, b5);
    loss = norm(a5-y,2)^2;
end
%% prediction: threshold the sigmoid output at 0.5
function pred = predict(W2,W3,W4,W5,b2,b3,b4,b5,x)
    a2 = activate(x, W2, b2);
    a3 = activate(a2, W3, b3);
    a4 = activate(a3, W4, b4);
    a5 = activate(a4, W5, b5);
    pred = round(a5);           % 1 if a5 >= 0.5, otherwise 0
end
%% activation function: logistic sigmoid of the affine map W*x + b
function y = activate(x,W,b)
    y = 1./(1+exp(-(W*x+b)));   % b is added to every column (implicit expansion)
end
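The reshape-based batching in the question only works when num_data is an exact multiple of batch_size; otherwise reshape raises an error. A minimal sketch of an alternative that keeps a shorter final batch, assuming hypothetical sizes num_data = 103 and batch_size = 50 (not values from the post):

```matlab
% Hypothetical sizes chosen so the data does not split evenly.
num_data   = 103;
batch_size = 50;

rand_idx = randperm(num_data);              % shuffle once per epoch
batches  = {};                              % cell array of index vectors
for s = 1:batch_size:num_data
    e = min(s + batch_size - 1, num_data);  % the last batch may be shorter
    batches{end+1} = rand_idx(s:e);         %#ok<AGROW>
end

% Every sample should appear in exactly one batch.
all_idx = [batches{:}];
```

Inside the epoch loop, iterating `for k = 1:numel(batches)` with `idx = batches{k}` would take the place of `for idx = rand_idx`; the existing `1/length(idx)` factor already adapts to the shorter last batch.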

Answers (2)

Mahesh Taparia on 2 Apr 2021
Hi,
You mentioned that you are implementing a classification network. In your code, you are using the square of the L2 norm as the loss, and the loss derivative used during backpropagation is also not correct. Moreover, since this is a classification network, use a classification loss such as cross-entropy or focal cross-entropy instead of the norm. Maybe this is the reason you are getting 0 every time.
You can also use MATLAB's built-in functions to perform backpropagation. For this, refer to the link given below:
Hope this helps!
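One way to follow the cross-entropy suggestion, sketched here as an assumption rather than the answerer's code: with a sigmoid output a and binary cross-entropy L = -sum(y.*log(a) + (1-y).*log(1-a)), the derivative with respect to the output pre-activation simplifies to a - y, because the sigmoid derivative cancels. In the question's loop the output delta would then become delta5 = a5 - label(idx), while delta4, delta3 and delta2 keep their current form. The snippet checks that identity numerically on hypothetical values:

```matlab
% Sigmoid and summed binary cross-entropy for 0/1 targets.
sigmoid = @(z) 1 ./ (1 + exp(-z));
bce     = @(a, y) -sum(y .* log(a) + (1 - y) .* log(1 - a));

% Hypothetical output pre-activations and labels for one mini-batch.
z = [-2.0, -0.5, 0.3, 1.7];
y = [ 0,    1,   0,   1  ];

% Analytic gradient of bce(sigmoid(z), y) with respect to z.
grad_analytic = sigmoid(z) - y;

% Central finite differences, one entry of z at a time.
h = 1e-6;
grad_numeric = zeros(size(z));
for k = 1:numel(z)
    e = zeros(size(z));
    e(k) = h;
    grad_numeric(k) = (bce(sigmoid(z + e), y) - bce(sigmoid(z - e), y)) / (2*h);
end
```

In practice, clipping a slightly away from exactly 0 and 1 before taking the log is a common safeguard against log(0).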
1 comment
konoha on 2 Apr 2021 (edited by konoha on 2 Apr 2021)
The derivative of the MSE is -(y - f(x)) f'(x). I don't follow your suggestions.
Thank you.
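Whether the backpropagation formulas in the question match the squared-error loss can be checked with central finite differences. A minimal sketch, assuming a smaller hypothetical 2-5-1 network with random data (the question's network is 2-5-5-5-1), that compares the backprop gradient of W2 with a numerical estimate:

```matlab
% Same sigmoid layer as the question's activate().
act = @(x, W, b) 1 ./ (1 + exp(-(W*x + b)));

% Hypothetical data and parameters (not from the post).
X  = randn(2, 8);                    % 2 features, 8 samples
y  = double(rand(1, 8) > 0.5);       % 0/1 labels
W2 = -1 + 2*rand(5, 2);  b2 = -1 + 2*rand(5, 1);
W3 = -1 + 2*rand(1, 5);  b3 = -1 + 2*rand(1, 1);

% Loss 0.5*||a - y||^2 as a function of W2 only.
loss = @(W2) 0.5 * norm(act(act(X, W2, b2), W3, b3) - y, 2)^2;

% Backprop in the question's form (without the 1/length(idx) factor).
a2 = act(X, W2, b2);
a3 = act(a2, W3, b3);
delta3 = a3 .* (1 - a3) .* (a3 - y);
delta2 = a2 .* (1 - a2) .* (W3' * delta3);
gW2 = delta2 * X';

% Central finite differences on every entry of W2.
h = 1e-6;
gW2_num = zeros(size(W2));
for k = 1:numel(W2)
    E = zeros(size(W2));
    E(k) = h;
    gW2_num(k) = (loss(W2 + E) - loss(W2 - E)) / (2*h);
end
rel_err = norm(gW2(:) - gW2_num(:)) / norm(gW2(:) + gW2_num(:));
```

A very small rel_err (for example below 1e-6) indicates the analytic gradient is consistent with the squared-error loss; a value close to 1 would point to a derivative bug. The same loop can be repeated for the other weights and biases of the full network.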



Mohamed Salem on 22 Dec 2022
Write MATLAB code that implements the delta learning rule with mini-batches.
Compare (with a graph) your mini-batch algorithm with SGD and the batch algorithm in terms of mean squared error.
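A minimal sketch of that exercise, under assumptions not stated in the post: a single linear neuron without bias, trained with the delta (Widrow-Hoff/LMS) rule on synthetic regression data, using the batch-averaged update and batch sizes 1 (SGD), 20 (mini-batch) and N (full batch). mse_hist records the training MSE after each epoch for each method:

```matlab
% Hypothetical linear-regression data: y = w_true' * X + noise.
N = 200;  d = 3;
X = randn(d, N);
w_true = [1.5; -2.0; 0.5];
y = w_true' * X + 0.1 * randn(1, N);

eta    = 0.05;                   % learning rate (assumed)
epochs = 30;
batch_sizes = [1, 20, N];        % SGD, mini-batch, full batch
mse_hist = zeros(numel(batch_sizes), epochs);

for m = 1:numel(batch_sizes)
    B = batch_sizes(m);
    w = zeros(d, 1);             % same initial weights for every method
    for ep = 1:epochs
        perm = randperm(N);
        for s = 1:B:N
            idx = perm(s:min(s + B - 1, N));
            err = y(idx) - w' * X(:, idx);               % 1 x numel(idx) errors
            w = w + eta * X(:, idx) * err' / numel(idx); % delta rule, batch-averaged
        end
        mse_hist(m, ep) = mean((y - w' * X).^2);         % MSE on all data after the epoch
    end
end
```

For the requested graph, `plot(1:epochs, mse_hist'); legend('SGD', 'mini-batch', 'batch'); xlabel('epoch'); ylabel('MSE');` draws one curve per method. With the same learning rate for all three, full-batch descent takes only one step per epoch, so its MSE typically falls more slowly per epoch than the other two.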

Category: Deep Learning Toolbox

Release: R2020b
