Custom kernel error when using fitrgp(): periodic / locally periodic kernels
I am using fitrgp() to fit a Gaussian process regression model to generated data. To test it, I created a dummy example in which a periodic signal y(t) (the sum of two sinusoids, as shown in the code) is the output and only the time t = 1:N is given as the input. It technically works, but the fit is poor, even though I would expect the custom periodic kernel to be able to capture this problem.
I am also unsure whether my implementations of the ARD squared exponential, periodic and locally periodic kernels are correct (see the function below). I tested the ARD squared exponential kernel on other multidimensional input data and it seems to work, so I suspect there is a bug somewhere in the periodic kernel.
The code below runs standalone. Because it ends with a local function, copy and paste it into a script file and run that, rather than pasting it into the Command Window.
Any help would be very much appreciated. Thanks!
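For reference, the three kernels are usually written as shown below. This appears to be what custom_kernel is meant to compute. In the notation: ℓ_d are the squared-exponential length scales, ℓ_p is the periodic length scale, σ_f is the signal standard deviation, p is the period and D is the number of input columns. These symbols are labels chosen here, not names from the code.

```latex
\begin{aligned}
k_{\mathrm{SE}}(\mathbf{x},\mathbf{x}') &= \sigma_f^2 \exp\!\Big(-\tfrac{1}{2}\sum_{d=1}^{D}\frac{(x_d-x'_d)^2}{\ell_d^2}\Big)\\
k_{\mathrm{Per}}(x,x') &= \sigma_f^2 \exp\!\Big(-\frac{2\sin^2\!\big(\pi\,|x-x'|/p\big)}{\ell_p^2}\Big)\\
k_{\mathrm{LocPer}}(\mathbf{x},\mathbf{x}') &= \sigma_f^2 \exp\!\Big(-\tfrac{1}{2}\sum_{d=1}^{D-1}\frac{(x_d-x'_d)^2}{\ell_d^2}\Big)\exp\!\Big(-\frac{2\sin^2\!\big(\pi\,|x_D-x'_D|/p\big)}{\ell_p^2}\Big)
\end{aligned}
```

Compared term by term with the code, the periodic factors in modes 0 and 2 have this form. In mode 1, and in the squared-exponential factor of mode 0, there is a difference. pdist2 with 'seuclidean' divides each coordinate difference by the supplied scale vector. Passing theta(...).^2 therefore divides by ℓ_d² rather than ℓ_d.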
clear all; clc;
%% Setup system
% create time and define period
t = 1:5000;
period = 115/2;
% output
y = sin(2*pi*t/period) + 0.8*sin(2*pi*t/period/2) + 0.1*randn(1,size(t,2));
% input
z = t;
%z = mod(t,period)/100;
plot(y)
%% Train GPs
% training set length
n = 750;
% Initial noise standard deviation (Sigma)
sigma0 = std(y');
% Initial kernel parameters for mode 2: [length scale; signal std; period]
theta0 = 0.1*ones(size(z,1)+2,1);
theta0(end) = 115;
theta0(end-1) = 0.1;
offset = 0;
tic
% | 1:ARD squared exponential | 2:periodic | 3:local periodic |
gps = fitrgp((z(:,1 + offset: n + offset))',y(1,1 + offset: n + offset)','OptimizeHyperparameters','auto',...
'KernelFunction',@(xq,xp,theta) custom_kernel(xq,xp,theta,2),'kernelparameters',theta0,'BasisFunction','none',...
'HyperparameterOptimizationOptions',struct('UseParallel',true,'MaxObjectiveEvaluations',20,...
'Optimizer','bayesopt'),'Sigma',sigma0,'Standardize',1);
toc
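One possible cause of the poor fit is predictor standardization. This is a hypothesis, not a verified diagnosis. With 'Standardize',1, fitrgp centers and scales each predictor column by its mean and standard deviation before the kernel is evaluated. The initial period theta0(end) = 115 is then interpreted in standardized units, in which the 750-sample training window spans only about 3.5 units.

The sketch below computes what a 115-sample period corresponds to on that scale. It assumes the scaling uses the sample standard deviation returned by std. The names zTrain, periodSamples, zStd, periodStd and spanStd are hypothetical.

```matlab
% Hypothetical check: express the initial period in the standardized units
% that the kernel sees when fitrgp is called with 'Standardize',1.
n = 750;                              % training set length, as in the question
zTrain = (1:n)';                      % training inputs (column vector)
periodSamples = 115;                  % theta0(end) in the question, in samples
zStd = std(zTrain);                   % sample standard deviation of the inputs
periodStd = periodSamples / zStd;     % same period after standardization
spanStd = (max(zTrain) - min(zTrain)) / zStd;   % width of the training window
fprintf('std(zTrain) = %.2f\n', zStd);
fprintf('period in standardized units = %.4f\n', periodStd);
fprintf('training window in standardized units = %.4f\n', spanStd);
```

If standardization is the culprit, there are two options: remove 'Standardize',1, or start the period near periodStd instead of 115.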
Plotting 1-step predictions:
np = size(z,2)-n-1; % number of predictions
figure
subplot(2,1,1)
respred1 = resubPredict(gps);
plot(y(1,1:n));
hold on
plot(respred1)
title('1 step prediction - training data','interpreter','latex')
leg = legend('Data','Model');
set(leg,'Interpreter','latex');
ylabel('Amplitude [$-$]','interpreter','latex')
%
subplot(2,1,2)
[respred1,~,ress_ci] = predict(gps, (z(:,n:n+np))');
plot(y(1,n:(n+np)));
hold on
plot(respred1)
hold on
ciplot(ress_ci(:,1),ress_ci(:,2)) % shaded confidence band; ciplot is a File Exchange function, not built into MATLAB
title('1 step prediction - validation data','interpreter','latex')
leg = legend('Data','Model');
set(leg,'Interpreter','latex');
xlabel('Time [10 s]','interpreter','latex')
ylabel('Amplitude [$-$]','interpreter','latex')
gps.KernelInformation.KernelParameters(:)
Custom kernel function:
function k_xx = custom_kernel(xp,xq,theta,mode)
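% theta structure: theta = [L_se; L_p; sigma; P]
%   L_se  : ARD squared exponential length scales (theta(1:end-2) in mode 1, theta(1:end-3) in mode 0)
%   L_p   : periodic length scale, theta(end-2) (modes 0 and 2)
%   sigma : signal standard deviation, theta(end-1)
%   P     : period, theta(end) (modes 0 and 2)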
% ARD squared exponential
if mode == 1
k_xx = (theta(end-1).^2)*exp( -0.5*(pdist2(xp,xq,'seuclidean',theta(1:end-2).^2)).^2 );
% Local periodic
elseif mode == 0
k_xx = (theta(end-1).^2) * exp( -0.5*(pdist2(xp(:,1:end-1),xq(:,1:end-1),'seuclidean',theta(1:end-3).^2)).^2 )...
.* exp( -2* sin( (pi/theta(end))* (xp(:,end)-xq(:,end)') ).^2 * (1/theta(end-2)).^2 );
% Periodic
elseif mode == 2
k_xx = (theta(end-1).^2) * exp(-2* sin( (pi/theta(end))* abs(xp-xq') ).^2 * (1/theta(end-2)).^2 );
end
end
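For comparison, here is a minimal sketch of the periodic kernel (mode 2) and a locally periodic kernel for a single time input, written as anonymous functions.

It assumes the parameterization used in the custom kernel example in the fitrgp documentation. There, the kernel parameters are unconstrained and the kernel applies exp() to obtain positive length scales, amplitudes and periods. The names periodicKernel and locPeriodicKernel, and the parameter orders, are hypothetical. The checks at the end confirm two things: k(x,x) = σ_f², and points exactly one period apart are fully correlated.

```matlab
% Hypothetical sketches: theta holds log-parameters, so every parameter
% passed through exp() stays positive while fitrgp optimizes theta freely.

% Periodic kernel, theta = log([ell; sigmaF; p]).
periodicKernel = @(XM, XN, theta) exp(theta(2))^2 * ...
    exp(-2 * sin(pi * pdist2(XM, XN) / exp(theta(3))).^2 / exp(theta(1))^2);

% Locally periodic kernel: the periodic kernel multiplied by a squared
% exponential, theta = log([ell; sigmaF; p; ellSE]).
locPeriodicKernel = @(XM, XN, theta) periodicKernel(XM, XN, theta(1:3)) .* ...
    exp(-pdist2(XM, XN).^2 / (2 * exp(theta(4))^2));

x = (1:300)';                      % 1-D time input, column vector
thetaP  = log([1; 1; 115]);        % ell = 1, sigmaF = 1, period = 115 samples
thetaLP = log([1; 1; 115; 500]);   % same, plus an SE length scale of 500 samples
Kp  = periodicKernel(x, x, thetaP);
Klp = locPeriodicKernel(x, x, thetaLP);

% Diagonal, one full period apart, and roughly half a period apart
fprintf('k(1,1) = %.4f, k(1,116) = %.4f, k(1,58) = %.4f\n', ...
    Kp(1,1), Kp(1,116), Kp(1,58));
```

Such a function could be passed to fitrgp as 'KernelFunction',periodicKernel together with 'KernelParameters',thetaP. With 'Standardize',1, the period entry would have to be given in standardized units.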
1 comment
Dominik Friml
on 26 May 2022
Hi, I am far from an expert on this problem, but I stumbled on it and tried to solve it, building on an answer posted to a different question. According to "Introduction to Gaussian processes" by David MacKay, the periodic kernel is defined quite differently from what you posted.
[Image: periodic covariance function from MacKay, David J.C., "Introduction to Gaussian processes"]
Here is the code, which seems to work for me. Any ideas for improvement?
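The formula image is not reproduced here. Below is MacKay's periodic covariance function as it is usually quoted. It is a reconstruction from the standard reference, not a copy of the missing image. Here λ_i is the period and r_i the length scale of input dimension i, and I is the number of input dimensions.

```latex
C(\mathbf{x},\mathbf{x}') = \theta_1 \exp\!\left[-\frac{1}{2}\sum_{i=1}^{I}\left(\frac{\sin\!\big(\pi (x_i - x'_i)/\lambda_i\big)}{r_i}\right)^{2}\right]
```

The mykernel function below follows this structure with two differences. It squares θ, and it additionally divides each sine by λ_i, which rescales the effective length scale to r_i λ_i.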
%% Example data
rng(0,'twister'); close all; clear all
x = (1:5000)';
period = 115/2;
y = sin(2*pi*x/period) + 0.8*sin(2*pi*x/period/2) + 0.1*randn(size(x,1),1);
N=500;
X=[x(1:N), x(1:N)];
Y=y(1:N);
sigma0 = std(Y);
D = size(X,2);
%% Initial values of the unconstrained kernel parameters (mykernel applies exp() to them)
theta = 0.1;
r(1) = 1/50;
r(2) = 1/25;
lambda(1) = 0.2;
lambda(2) = 0.2;
Theta0 = [theta; r(1); r(2); lambda(1); lambda(2)];
%% Fit the model using custom kernel function
%gpr = fitrgp(X,Y,'kernelfunction',@mykernel,'kernelparameters',Theta0,'verbose',1)
gpr = fitrgp(X,Y,'OptimizeHyperparameters','auto',...
'kernelfunction',@mykernel,'kernelparameters',Theta0,'BasisFunction','none',...
'HyperparameterOptimizationOptions',struct('UseParallel',true,'MaxObjectiveEvaluations',10,...
'Optimizer','bayesopt'),'Sigma',sigma0,'Standardize',1);
%% Plot results
figure
plot(x(1:1000),y(1:1000),'k');
hold on;
plot(X(:,1),Y,'r');
hold on;
[respred1,~,ress_ci] = predict(gpr,[x(1:1000), x(1:1000)]);
plot(x(1:1000),respred1,'b')
hold on
ciplot(ress_ci(:,1),ress_ci(:,2))
%% Display kernel parameters
Theta0
ThetaHat = gpr.KernelInformation.KernelParameters
function KMN = mykernel(XM,XN,Theta)
%% mykernel - Compute a periodic kernel in MacKay's form (exponent summed over input dimensions).
% KMN = mykernel(XM,XN,Theta) takes an M-by-D matrix XM and an N-by-D matrix
% XN and computes an M-by-N matrix KMN of kernel products such that
% KMN(i,j) is the kernel product between XM(i,:) and XN(j,:). Theta is
% the 5-by-1 unconstrained parameter vector [log theta; log r(1:2); log lambda(1:2)].
D = size(XM,2);
%% 1. Convert theta into parameters.
params = exp(Theta);
theta = params(1);
r = params(2:3)';
lambda = params(4:5)';
%% 2. Periodic contribution, summed over input dimensions in the exponent.
KMN = ((sin(pi/lambda(1)*pdist2(XM(:,1), XN(:,1)))/lambda(1))/r(1)).^2;
for i = 2:D
KMN = KMN + ((sin(pi/lambda(i)*pdist2(XM(:,i), XN(:,i)))/lambda(i))/r(i)).^2;
end
KMN = (theta^2)*exp(-0.5*KMN);
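%% 3. Optional squared exponential term (uncomment to use). sigmaF2, sigmaL2
% and p are not defined above; they would need extra entries in Theta.
% Adding the SE term gives "periodic + SE"; a locally periodic kernel
% multiplies by it instead: KMN = KMN .* exp(-0.5*(pdist2(XM/sigmaL2, XN/sigmaL2).^2));
% KMN = KMN + (sigmaF2^2)*exp(-0.5*(pdist2(XM/sigmaL2, XN/sigmaL2).^2));
% KMN = (sigmaF2^2)*exp(-0.5*(pdist2(XM/sigmaL2, XN/sigmaL2).^2));
% KMN = (sigmaF2^2)*exp(-2*sin(pi*pdist2(XM, XN)/p).^2/sigmaL2^2);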
end
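To interpret the fitted values, map the unconstrained parameters returned by fitrgp back to positive parameters, as mykernel does (params = exp(Theta)). Then convert the periods from standardized units to samples.

The sketch below uses the initial values as a stand-in for ThetaHat so that it runs without fitting. It assumes that 'Standardize',1 scales each predictor column by std(X). The names thetaAmp and periodSamples are hypothetical. For comparison, the two sinusoids in the test signal have periods of 57.5 and 115 samples.

```matlab
% Stand-in for ThetaHat = gpr.KernelInformation.KernelParameters after fitting.
x = (1:5000)';
N = 500;
X = [x(1:N), x(1:N)];
ThetaHat = [0.1; 1/50; 1/25; 0.2; 0.2];

% Same transformation as in mykernel: unconstrained -> positive parameters.
params = exp(ThetaHat);
thetaAmp = params(1);        % amplitude theta
r = params(2:3);             % length scales r(1), r(2)
lambda = params(4:5);        % periods, in standardized units

% Undo the predictor standardization (assumed to divide by std(X)).
periodSamples = lambda .* std(X)';
fprintf('amplitude = %.4f\n', thetaAmp);
fprintf('r = [%.4f %.4f]\n', r);
fprintf('period in samples = [%.2f %.2f]\n', periodSamples);
```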
Answers (0)