Issue with obtaining gradient from Q-value critic using dlfeval, dlgradient, dlarrays

4 views (last 30 days)
I'm trying to implement a custom agent, and inside it I'm running into issues obtaining the gradient of the Q value with respect to my actor network parameters. My code is below; the main issue is with the learnImpl and actorupdate functions. Inside actorupdate I call disp several times to confirm the values are dlarrays, but when I feed Qtotal into dlgradient, I get the following error:
Error using dlarray/dlgradient (line 115)
'dlgradient' inputs must be traced dlarray objects or cell arrays, structures or tables containing traced dlarray objects. To enable tracing, use 'dlfeval'.
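For context, the tracing requirement the error refers to can be reproduced in isolation (a minimal sketch, unrelated to my agent):
x = dlarray(2);
y = dlgradient(x.^2, x);                   % errors: x is not traced here
y = dlfeval(@(x) dlgradient(x.^2, x), x);  % returns 4: dlfeval enables tracing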
classdef CustomDDPGAgent < rl.agent.CustomAgent
    properties
        %actor NN
        actor
        %critic for tracking target
        critic_track
        %critic for obstacle avoidance
        critic_obstacle
        %dimensions
        statesize
    end
    methods
        %constructor function
        function obj = CustomDDPGAgent(ActorNN,Critic_Track,Critic_Obst,statesize,actionsize)
            %(required) call abstract class constructor
            obj = obj@rl.agent.CustomAgent();
            %define observation + action space
            obj.ObservationInfo = rlNumericSpec([statesize 1]);
            obj.ActionInfo = rlNumericSpec([actionsize 1],LowerLimit = -1,UpperLimit = 1);
            obj.SampleTime = 0.01;
            %define the actor and 2 critics
            obj.actor = ActorNN;
            obj.critic_track = Critic_Track;
            obj.critic_obstacle = Critic_Obst;
            %record observation dimensions
            obj.statesize = statesize;
        end
    end
    methods (Access = protected)
        %Actor update based on Q value
        function actorgradient = actorupdate(obj,Observation)
            Obs_Obstacle = {dlarray([])};
            for index = 1:20
                Obs_Obstacle{1}(index) = Observation{1}(index);
            end
            disp(Observation);
            disp(Obs_Obstacle);
            action = evaluate(obj.actor,Observation,UseForward=true);
            disp(action);
            %Obtain combined Q values
            Qtrack = getValue(obj.critic_track,Observation,action);
            Qobstacle = getValue(obj.critic_obstacle,Obs_Obstacle,action);
            Qtotal = Qtrack + Qobstacle;
            Qtotal = sum(Qtotal);
            disp(Qtotal);
            %obtain gradient of Q value wrt parameters of actor network
            actorgradient = dlgradient(Qtotal,obj.actor.Learnables);
        end
        %Action method
        function action = getActionImpl(obj,Observation)
            % Given the current state of the system, return an action
            action = getAction(obj.actor,Observation);
        end
        %Action with noise method
        function action = getActionWithExplorationImpl(obj,Observation)
            % Given the current observation, select an action
            action = getAction(obj.actor,Observation);
            % Add random noise to action
        end
        %Learn method
        function action = learnImpl(obj,Experience)
            %parse experience
            Obs = Experience{1};
            %reformat in dlarrays
            Obs_reformat = {dlarray(Obs{1})};
            action = getAction(obj.actor,Obs_reformat);
            %update actor network
            ActorGradient = dlfeval(@actorupdate,obj,Obs_reformat);
        end
    end
end
4 comments
Matt J on 7 Apr 2025
Edited: Matt J on 7 Apr 2025
"The functions involved with the error is the learnimp() and the actorupdate()."
So you say, but you still haven't provided steps that we can run to reproduce the error.
Vincent on 7 Apr 2025
I really don't understand what you want me to do. For you to run it, I'd have to provide my whole project. Do you see any error in my code? I'm doing everything the documentation tells me, and using disp() I see that all of my values are single dlarrays. Hence my argument that what I pass into dlgradient is in the proper format.


Answers (1)

Matt J on 8 Apr 2025
Edited: Matt J on 9 Apr 2025
"I really don't understand what you want me to do. In order for you to run it, I'd have to provide my whole project to you. Do you see any error in my code?"
Ideally, you would provide simplified steps and input data that reproduce the error.
In any case, as long as actorupdate() is called only from the line
ActorGradient = dlfeval(@actorupdate,obj,Obs_reformat);
in learnImpl(), I don't see any cause for errors. I therefore suspect that actorupdate() is being called in some other way. The fact that Qtotal has been verified to be a dlarray is irrelevant: what matters is whether it is a traced dlarray, and tracing is only enabled inside a dlfeval call.
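To illustrate (a minimal sketch, not specific to your agent): disp shows the same thing for a traced and an untraced dlarray, so inspecting the values tells you nothing about whether tracing is active.
x = dlarray(2);
disp(x.^2)                                  % prints an ordinary-looking dlarray
dlgradient(x.^2, x)                         % errors: no tracing in this context
disp(dlfeval(@(x) dlgradient(x.^2, x), x)) % prints 4: gradient of x^2, traced via dlfeval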
Without even simplified steps to reproduce what you are seeing, my suspicions can't be tested. However, the error message should show the call stack, so you can verify whether dlgradient was reached from dlfeval. You could also insert a call to dbstack right before the line
actorgradient = dlgradient(Qtotal,obj.actor.Learnables);
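so that the stack is printed every time actorupdate() runs; something like this (a sketch of the instrumentation, placed inside your actorupdate):
dbstack   % dlfeval (or its internals) should appear in this stack if the call is traced
actorgradient = dlgradient(Qtotal,obj.actor.Learnables);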



Version

R2024a
