How to speed up vectorized operations for dynamic programming
Alessandro
on 14 Sep 2024
Edited: Alessandro
on 14 Sep 2024
I would like to speed up the following code which solves a discrete dynamic programming problem using the method of successive approximations, as described e.g. in Bertsekas.
The algorithm is made of two steps. In step 1, I precompute the payoff array R(a',a,z) where a' is the action and (a,z) are the states. In step 2 I compute the value function using the method of successive approximations: I guess V0, then I compute an updated V1 and finally I check if ||V1-V0|| is less than a tolerance level. If it is, I stop, otherwise I set V0=V1 and go on.
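For reference, the update in step 2 can be written as a Bellman iteration. The formula below is a sketch inferred from the MWE in this post, not something stated in the text; the symbols $u$, $P$ (for `PZ`), $\mathcal{A}$ (for `a_grid`) and $\varepsilon$ (for `tol`) are notation assumed here.

```latex
V_1(a,z) \;=\; \max_{a' \in \mathcal{A}}
  \Big\{ R(a',a,z) \;+\; \beta \sum_{z'} P(z,z')\, V_0(a',z') \Big\},
\qquad
R(a',a,z) \;=\;
\begin{cases}
  u\big((1+r)\,a + z - a'\big) & \text{if } (1+r)\,a + z - a' > 0,\\[2pt]
  -\infty & \text{otherwise,}
\end{cases}
\qquad
u(c) \;=\;
\begin{cases}
  \log c & \text{if } \sigma = 1,\\[2pt]
  \dfrac{c^{\,1-\sigma}}{1-\sigma} & \text{otherwise.}
\end{cases}
% Stopping rule used in the code: iterate until
% \max_{a,z} |V_1(a,z) - V_0(a,z)| \le \varepsilon .
```

In the code, the sum over $z'$ is the matrix product `V0*PZ'`, and the maximum over $a'$ is taken along the first dimension of the three-dimensional array.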
I profiled the code (see the MWE below) and the two most time-consuming lines are the following:
(1) RHS = Ret+beta*permute(EV,[1,3,2]);
(2) [max_val,max_ind] = max(RHS,[],1);
Line (1) takes up 63% of the total running time, line (2) takes 32%. As you can see I have already vectorized all loops.
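One plausible reason line (1) dominates is the size of the arrays involved: `RHS` has the same size as `Ret`, i.e. n_a-by-n_a-by-n_z, and a full array of that size is produced on every iteration. The snippet below is only a back-of-the-envelope sketch of that footprint, assuming double precision (8 bytes per element) and n_z = 2 as in the MWE; it gives about 4 MB per array for n_a = 500 and about 1.6 GB for n_a = 10000.

```matlab
% Sketch: memory needed for one n_a-by-n_a-by-n_z array of doubles.
% Assumption: 8 bytes per element (MATLAB's default double type).
n_z = 2;
bytes_per_double = 8;
for n_a = [500 10000]
    mb = n_a^2 * n_z * bytes_per_double / 1e6;   % megabytes (10^6 bytes)
    fprintf('n_a = %5d: %.1f MB per array\n', n_a, mb);
end
```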
I would be very grateful for any suggestions. I post an MWE below. (Note that I set n_a, the number of grid points, to a low value on purpose so that interested users can run the example quickly. In my actual code, n_a=10000 or more.)
%% Solve income fluctuation problem CPU
clear;clc;close all
%% Economic parameters
sigma = 2;
r = 0.03;
beta = 0.96;
PZ = [0.60 0.40;
0.05 0.95];
z_grid = [0.5 1.0]';
n_z = length(z_grid);
b = 0; %lower bound for asset holdings a
grid_max = 4;
n_a = 500; % IN PRACTICE THIS IS EQUAL TO 5000-10000
R = 1+r;
a_grid = linspace(-b,grid_max,n_a)';
if sigma==1
    fun_u = @(c) log(c);
else
    fun_u = @(c) c.^(1-sigma)/(1-sigma);
end
%% Computational parameters
verbose = 0;
tiny = 1e-8; %very small positive number
tol = 1e-6; %tolerance for VFI and TI
max_iter = 500; %maximum num. of iterations for both VFI and TI
%% Start timing
tic
%% STEP 1- Precompute current payoff array R(a',a,z)
a_tomorrow = a_grid; %(a',1,1)
a_today = a_grid'; %(1,a,1)
z_today = shiftdim(z_grid,-2); %(1,1,z)
cons = (1+r)*a_today+z_today-a_tomorrow;
Ret = fun_u(cons); %size: [n_a,n_a,n_z]
Ret(cons<=0) = -inf;
%% STEP 2 - Value function iteration
iter = 1;
err = tol+1;
V0 = zeros(n_a,n_z);
while err>tol && iter<=max_iter
    EV = V0*PZ'; %(a',z)
    RHS = Ret+beta*permute(EV,[1,3,2]);
    [max_val,max_ind] = max(RHS,[],1);
    V1 = squeeze(max_val);
    pol_ind_ap = squeeze(max_ind);
    err = max(abs(V0(:)-V1(:)));
    if verbose==1
        fprintf('iter = %d, err = %f \n',iter,err)
    end
    iter = iter+1;
    V0 = V1;
end
if err>tol
    error('VFI did not converge!')
else
    % iter is incremented after every pass, so the number of passes is iter-1
    fprintf('VFI converged after %d iterations \n',iter-1)
end
pol_ap = a_grid(pol_ind_ap);
pol_c = (1+r)*a_grid+z_grid'-pol_ap;
%% End timing
toc
%% Figures
figure
plot(a_grid,pol_c(:,1),'linewidth',2)
hold on
plot(a_grid,pol_c(:,2),'linewidth',2)
legend('Low shock','High shock','Location','NorthWest')
xlabel('asset level')
ylabel('consumption')
title('Consumption Policy Function')
figure
plot(a_grid,a_grid,'--','linewidth',2)
hold on
plot(a_grid,pol_ap(:,1),'linewidth',2)
hold on
plot(a_grid,pol_ap(:,2),'linewidth',2)
legend('45 line','Low shock','High shock','Location','NorthWest')
xlabel('Current period assets')
ylabel('Next-period assets')
title('Assets Policy Function')
Accepted Answer
Matt J
on 14 Sep 2024
Edited: Matt J
on 14 Sep 2024
This might be a little faster.
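betaPZtransp=beta*PZ'; % fold the discount factor into the transition matrix once
tic
% iter, err and V0 are initialized as in the question
while err>tol && iter<=max_iter
    RHS = Ret + reshape(V0*betaPZtransp,n_a,1,n_z); % reshape instead of permute
    V1 = max(RHS,[],1); % maximum only; the argmax is recovered after the loop
    err = norm( V0(:)-V1(:) ,inf);
    if verbose
        fprintf('iter = %d, err = %f \n',iter,err)
    end
    iter = iter+1;
    V0 = reshape(V1,n_a,n_z);
end
toc
[V1,pol_ind_ap]=max(RHS,[],1); % policy indices from the last RHS
pol_ind_ap = reshape(pol_ind_ap, n_a,n_z);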
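The answer swaps `permute` and `squeeze` for `reshape`. The sketch below is one way to check that these swaps leave the values unchanged; the small sizes and the random arrays `EV` and `M` are stand-ins chosen purely for illustration and are not taken from the post.

```matlab
% Sketch with hypothetical small sizes; the random data is illustrative only.
n_a = 4; n_z = 2;
EV = rand(n_a,n_z);                 % stand-in for V0*PZ'

A = permute(EV,[1,3,2]);            % size [n_a,1,n_z], as in the question
B = reshape(EV,n_a,1,n_z);          % same size, as in the answer
same_rhs_term = isequal(A,B);       % element order is unchanged, only the shape

M = max(rand(n_a,n_a,n_z),[],1);    % size [1,n_a,n_z], like max(RHS,[],1)
same_value = isequal(squeeze(M), reshape(M,n_a,n_z));

disp([same_rhs_term, same_value])
```

Both comparisons should hold because moving a singleton dimension does not change the column-major order of the elements. Using `reshape` with explicit sizes also keeps the result n_a-by-n_z if n_z were 1, a case where `squeeze` would leave a 1-by-n_a row untouched.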