複数の入力データを用いたLSTMによる学習と,予測テストのための入力データの形について
19 visualizaciones (últimos 30 días)
Mostrar comentarios más antiguos
LSTMを用いた時系列データの学習と予測について質問です.
学習モデルを構築したあとの,predict関数による予測とその出力に問題が発生しており,解決策が思いつかないため質問を設けさせていただきました.
学習の概要は以下の通りです.
・ 学習のための入力データA,B,Cはそれぞれ5×200 double型の行列データであり,それぞれの1列は1タイムステップに相当する.
・ A,B,Cを入力データとして学習させたいため,trainNetworkに入力する際は3×1 cellのデータを作成し入力する.(各cellにはA,B,Cのデータが入っている)
・ tを入力しt+1を出力するように学習を行う.
以上の内容を実際に書いたものが以下になります.
% Load data
XTrain_A = Data_A(:,1:end-1);
YTrain_A = Data_A(:,2:end);
XTrain_B = Data_B(:,1:end-1);
YTrain_B = Data_B(:,2:end);
XTrain_C = Data_C(:,1:end-1);
YTrain_C = Data_C(:,2:end);
% Prepare for input dataset
XTrain = {XTrain_A;XTrain_B;XTrain_C};
YTrain = {YTrain_A;YTrain_B;YTrain_C};
% Layers and options
numFeatures = 5;
numResponses = 5;
numHiddenUnits = 100;
layers = [ ...
sequenceInputLayer(numFeatures)
flattenLayer('Name','flatten')
lstmLayer(numHiddenUnits,'OutputMode','sequence')
fullyConnectedLayer(numResponses)
regressionLayer];
options = trainingOptions('sgdm', ...
'MaxEpochs',200, ...
'GradientThreshold',1, ...
'InitialLearnRate',0.005, ...
'LearnRateSchedule','piecewise', ...
'LearnRateDropPeriod',100, ...
'LearnRateDropFactor',0.3, ...
'Verbose',0, ...
'Plots','training-progress');
% Machine learning
net = trainNetwork(XTrain,YTrain,layers,options);
ここで作成した学習モデルを評価するため,A,B,Cと同様の形のDというデータセットとpredict関数を用い以下のように予測を行ったところ,局所解に陥ったような結果(同じ一定の値を最後まで出力し続ける)になりました.
予測のためのデータセットはDのみであるため,入力の形はcellではなく5×200 doubleそのままになっています.
% Load data
Test_D = Data_A(:,1:end-1);
for i = 1:199;
Pred_D(:,i+1) = predict(net,Test_D(:,i),'ExecutionEnvironment','cpu');
end
これは入力データや学習の仕方もしくは予測を行う際の入力の仕方でなにか間違えてしまっているのでしょうか.
(A,B,C,Dのそれぞれは時間経過による値の変化率に差はあるものの,定性的には似たような変化を示すデータとなっています)
長くなりましたが,ミスしている箇所や解決策がわからず途方に暮れてしまっているため,些細なことでも構いませんのでご指摘いただければ幸いです.
3 comentarios
Kenta
el 10 de Mayo de 2020
こんにちは、丁寧に教えていただきありがとうございます。回答のほうもご覧ください。コードもありがとうございます。
1)について、3通りの方法を詳しく教えていただきありがとうございます。私は、3.の方法が適していると思います。すでにご覧になっていると思いますが、こちら の例もすでにご覧になっているとは思いますが、例えば、100個のデータセットがあって、それぞれの実験で、19個のセンサー、280のタイムステップがあると、100個のセルのうちの1つに19*280のdouble型の変数が格納されています。
2)について、丁寧にありがとうございます。いただいたコードでは、2の方法でうまくいったので、こちらで解決できないかと願っています...
3)について、こちらもありがとうございます。1)の1.と3.はデータの組が1つのときは同じことを言っていると思います。こちらで両方試しましたが結果は同じになりました。
予測に関しては、添付いただいたコードの方では、t-1のデータ(例えばセンサーデータ5つ分)からtのデータを予測しています。例えば、データの形状がy=sinxのような形だと、t-1の値が0.5でも、そこでの微分値がないとtの値はわかりませんね。それと同じ理由で、LSTMのように時系列データ用の関数に読み込ませないと履歴を蓄積できないのでうまく予測できない、ということなのではないでしょうか?(たぶん、勘違いしていたらすいません...)
補足もありがとうございます。
実装は回答に貼り付けたのですが、これでいかがでしょうか?
Respuesta aceptada
Kenta
el 10 de Mayo de 2020
X = XTrain{1};
numTimeSteps = size(X,2);
for i = 1:numTimeSteps
v = X(:,i);
[net,score] = predictAndUpdateState(net,v);
scores(:,i) = score;
end
figure;
for i=1:5
subplot(2,3,i)
plot(1:99,x(i,:));hold on
plot(1:99,scores(i,:));hold off
title(sprintf('the variable to predict (%d)',i))
legend('xdata','predicted data')
end
こんにちは、コメントのほう、ありがとうございます。上のようにして予測していくとうまくフィッティングできますがいかがでしょうか。もちろん、XTrainで学習・予測をしているので、良い結果が出ても不思議はないのですが、ひとまず上のようなコードで実行してもうまくいかないでしょうか。
3 comentarios
Kenta
el 11 de Mayo de 2020
こんにちは、うまくいったようで良かったです。手元のデータでもうまくフィッティングするといいですね!丁寧に状況説明していただき、こちらもわかりやすかったです。
データ以外で手を加えるとすると、
1)学習率を変える、2)オプティマイザ(adam,sgdmなど)を変える、3)層を深くする、4)L2正則化項を変える、などがあります。下のコードのコメントを見てください。
また、今回のLSTMではなく、biLSTM(双方向LSTM)というのもあります。
前から後ろの方向に加え、後ろの情報もあればわかりやすいとき(文章の判断など)は効果がありそうです。これも簡単に実行できます。
% Layers and options
numFeatures = 5;
numResponses = 5;
numHiddenUnits = 100; %精度が出ないとき、上げるといいかも
layers = [ ...
sequenceInputLayer(numFeatures)
flattenLayer('Name','flatten')
lstmLayer(numHiddenUnits,'OutputMode','sequence')
fullyConnectedLayer(numResponses)%精度が出ないとき、fully...を2つにしてもいいかも、下に例
% fullyConnectedLayer(20)
% fullyConnectedLayer(numResponses)
regressionLayer];
options = trainingOptions('sgdm', ...% adamにしたらいいかも
'MaxEpochs',200, ...% 学習がまだ進みそうだったら増やす
'GradientThreshold',1, ...% 重みが発散しないように抑える
'InitialLearnRate',0.005, ...% はじめは大きくして、だんだん小さくする
'LearnRateSchedule','piecewise', ...
'LearnRateDropPeriod',100, ...%学習が進むにつれて学習率を小さくする
%L2regulariztionなるものが設定できる(L2正則)
'LearnRateDropFactor',0.3, ...
'Verbose',0, ...
'Plots','training-progress');
WEBで探したところ、こちらにわかりやすい記事がありました。
とはいえ、学習曲線などを見ながら調整するものと思うので、やはりこれを手がかりにWEBなどで調べながら、また手元のデータで学習をすすめながら、ということが肝要に思います。
こちらの画像+LSTMのファイルも助けになるかもしれませんが、関係ないかもしれません...
Más respuestas (0)
Ver también
Categorías
Más información sobre Statistics and Machine Learning Toolbox en Help Center y File Exchange.
Community Treasure Hunt
Find the treasures in MATLAB Central and discover how the community can help you!
Start Hunting!