Complex number gradient using 'dlgradient' in conjunction with neural networks

Hello All,
I am trying to find the gradient of a function f that combines a complex-valued constant C with a feedforward neural network, where x is the input vector (real-valued) and θ are the parameters (real-valued). The output of the neural network is a real-valued array. However, because of the complex constant C, the function f becomes complex-valued. I would like to find its gradient with respect to the input vector x.
I tried to follow the method described at https://in.mathworks.com/help/deeplearning/ref/dlarray.dlgradient.html, which is given below (modified):
clc;
clear all;
x = linspace(1,10,5); % Real-valued array
x = dlarray(x,"CB"); % Converting to deeplearning array
[y, grad] = dlfeval(@gradFun,x);
grad = extractdata(grad)
grad =
4.0000 - 6.0000i 13.0000 -19.5000i 22.0000 -33.0000i 31.0000 -46.5000i 40.0000 -60.0000i
% Complex-function
function y = complexFun(x)
y = (2+3j)*x.^2;
end
% Function to calculate complex gradient
function [y,grad] = gradFun(x)
y = complexFun(x);
y = real(y);
grad = dlgradient(sum(y,"all"),x,'EnableHigherDerivatives',true);
end
The method successfully calculates the gradient of a complex-valued function (although it returns the conjugate). I then tried the same approach with the real-valued function replaced by the neural network. When I did this, I encountered the following error:
"Encountered complex value when computing gradient with respect to an output of fullyconnect. Convert all outputs of fullyconnect to real".
I would be grateful if anyone could show a way to fix the error and calculate the gradients.
Thank you,
Dr. Veerababu Dharanalakota

Accepted Answer

Walter Roberson on 7 Apr 2023
The derivative of C*f(x) follows from the product rule: dC/dx*f(x) + C*df/dx. When C is constant, whether real or complex valued, dC/dx is 0, so the derivative of C*f(x) is C*df/dx. The same logic applies to second derivatives.
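Written out, the differentiation step above reads as follows. This is only a restatement of the argument, with f standing for the real-valued part and C for the constant, as in the text.

```latex
\frac{d}{dx}\bigl[C\,f(x)\bigr]
  = \frac{dC}{dx}\,f(x) + C\,\frac{df}{dx}
  = C\,\frac{df}{dx},
\qquad
\frac{d^{2}}{dx^{2}}\bigl[C\,f(x)\bigr]
  = C\,\frac{d^{2}f}{dx^{2}},
\qquad\text{since } \frac{dC}{dx} = 0 .
```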
Therefore the gradient of C*f(x) is C times the gradient of f(x). If f(x) is real valued as indicated and C is complex valued, then unless the gradient is 0, the gradient of C*f(x) will be complex valued, which dlgradient will refuse to work with.
So take the dlgradient of f(x) and multiply the result by C. That should at least postpone the problem.
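A minimal sketch of that suggestion, under stated assumptions: the real-valued function x.^2 is a stand-in for the network (the actual network is not shown in the question), the value of C is borrowed from the question's example, and realGrad is a hypothetical helper name. The only point is that dlgradient sees purely real values, and C is applied afterwards on ordinary arrays.

```matlab
C = 2 + 3j;                           % complex constant (assumed value, taken from the question's example)
x = dlarray(linspace(1,10,5),"CB");   % real-valued input, as in the question

[f, dfdx] = dlfeval(@realGrad, x);    % differentiate the real-valued part only
gradCf = C * extractdata(dfdx);       % multiply by C outside automatic differentiation

% Hypothetical helper: returns the real-valued function and its gradient.
function [f, dfdx] = realGrad(x)
    f = x.^2;                         % stand-in; a network call would go here instead
    dfdx = dlgradient(sum(f,"all"), x);
end
```

Note that this returns C*df/dx itself rather than the conjugated value that appears in the question's output.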
2 Comments
Dr. Veerababu Dharanalakota on 7 Apr 2023
Thank you, Walter.
I am able to calculate the first derivative as well as the second derivative fxx, and constructed a function g = fxx + α f, where α is a real-valued constant. Now, I would like to minimize g with respect to the parameters θ. Since g is a complex-valued function, I split it into real and imaginary parts, and formulated the loss function and its gradient with respect to the parameters θ as follows:
g = fxx+alpha*f;
gr = real(g); % Real-part of g
gi = imag(g); % Imaginary-part of g
zeroTarget_r = zeros(size(gr),"like",gr); % Zero targets for the real-part
loss_r = l2loss(gr, zeroTarget_r); % Real-part loss function
zeroTarget_i = zeros(size(gi),"like",gi); % Zero targets for the imaginary-part
loss_i = l2loss(gi, zeroTarget_i); % Imaginary-part loss function
loss = loss_r+loss_i; % Total loss function (real-valued)
gradients = dlgradient(loss,parameters);
The loss function as well as the parameters are real-valued, so ideally I should be able to calculate the gradients. However, it is throwing a similar error (referring to inputs now instead of outputs):
"Encountered complex value when computing gradient with respect to an input to fullyconnect. Convert all inputs to fullyconnect to real".
I checked the individual loss values and the parameter values. They are purely real.
I hope your insight might help to resolve this issue as well.
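One possible way to keep complex values out of the differentiated computation altogether is sketched below. It assumes that f is the constant C times the real-valued network output, written here with the hypothetical symbol N, and that α is real as stated. Under those assumptions the loss can be rewritten in terms of a purely real quantity h, so that C never has to enter the code that dlgradient differentiates.

```latex
f = C\,N(x;\theta)
\;\Longrightarrow\;
g = f_{xx} + \alpha f = C\,\bigl(N_{xx} + \alpha N\bigr) = C\,h,
\qquad h = N_{xx} + \alpha N \in \mathbb{R},

\operatorname{Re}(g) = \operatorname{Re}(C)\,h,
\qquad
\operatorname{Im}(g) = \operatorname{Im}(C)\,h,

\lVert \operatorname{Re}(g) \rVert^{2} + \lVert \operatorname{Im}(g) \rVert^{2}
  = \bigl(\operatorname{Re}(C)^{2} + \operatorname{Im}(C)^{2}\bigr)\,\lVert h \rVert^{2}
  = \lvert C \rvert^{2}\,\lVert h \rVert^{2}.
```

If this holds for the actual model, the sum of the two losses above is proportional to the squared norm of h alone, so minimizing a loss built from the real-valued h (scaled by |C|^2 if the magnitude matters) should give the same minimizer without any complex value reaching fullyconnect.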
