複数の入力データを用​いたLSTMによる学​習と,予測テストのた​めの入力データの形に​ついて

19 visualizaciones (últimos 30 días)
Yuuki
Yuuki el 8 de Mayo de 2020
Comentada: Yuuki el 11 de Mayo de 2020
LSTMを用いた時系列データの学習と予測について質問です.
学習モデルを構築したあとの,predict関数による予測とその出力に問題が発生しており,解決策が思いつかないため質問を設けさせていただきました.
学習の概要は以下の通りです.
・ 学習のための入力データA,B,Cはそれぞれ5×200 double型の行列データであり,それぞれの1列は1タイムステップに相当する.
・ A,B,Cを入力データとして学習させたいため,trainNetworkに入力する際は3×1 cellのデータを作成し入力する.(各cellにはA,B,Cのデータが入っている)
・ tを入力しt+1を出力するように学習を行う.
以上の内容を実際に書いたものが以下になります.
% Load data
XTrain_A = Data_A(:,1:end-1);
YTrain_A = Data_A(:,2:end);
XTrain_B = Data_B(:,1:end-1);
YTrain_B = Data_B(:,2:end);
XTrain_C = Data_C(:,1:end-1);
YTrain_C = Data_C(:,2:end);
% Prepare for input dataset
XTrain = {XTrain_A;XTrain_B;XTrain_C};
YTrain = {YTrain_A;YTrain_B;YTrain_C};
% Layers and options
numFeatures = 5;
numResponses = 5;
numHiddenUnits = 100;
layers = [ ...
sequenceInputLayer(numFeatures)
flattenLayer('Name','flatten')
lstmLayer(numHiddenUnits,'OutputMode','sequence')
fullyConnectedLayer(numResponses)
regressionLayer];
options = trainingOptions('sgdm', ...
'MaxEpochs',200, ...
'GradientThreshold',1, ...
'InitialLearnRate',0.005, ...
'LearnRateSchedule','piecewise', ...
'LearnRateDropPeriod',100, ...
'LearnRateDropFactor',0.3, ...
'Verbose',0, ...
'Plots','training-progress');
% Machine learning
net = trainNetwork(XTrain,YTrain,layers,options);
ここで作成した学習モデルを評価するため,A,B,Cと同様の形のDというデータセットとpredict関数を用い以下のように予測を行ったところ,局所解に陥ったような結果(同じ一定の値を最後まで出力し続ける)になりました.
予測のためのデータセットはDのみであるため,入力の形はcellではなく5×200 doubleそのままになっています
% Load data
Test_D = Data_A(:,1:end-1);
for i = 1:199;
Pred_D(:,i+1) = predict(net,Test_D(:,i),'ExecutionEnvironment','cpu');
end
これは入力データや学習の仕方もしくは予測を行う際の入力の仕方でなにか間違えてしまっているのでしょうか.
(A,B,C,Dのそれぞれは時間経過による値の変化率に差はあるものの,定性的には似たような変化を示すデータとなっています)
長くなりましたが,ミスしている箇所や解決策がわからず途方に暮れてしまっているため,些細なことでも構いませんのでご指摘いただければ幸いです.
  3 comentarios
Yuuki
Yuuki el 10 de Mayo de 2020
Editada: Yuuki el 10 de Mayo de 2020
Kenta様,
ご回答ありがとうございます.
> ざっとみたところ、コード自体は動いているようですし、全体的には正しく動いていそうでした。
実行したところ特にエラーも出ず,学習過程はどれもRMSEおよびLossのどちらも順調に低下していました
まずこれまでやっとこととして,Aという時系列データのみを用いてまず学習させたあと,同じくAの1ステップ目のみをpredict関数を用いた予測に入力し逐次予測をしてみるという再現実験を行いました.
すると実際のデータでは
  1. 5*200 double の形で入力し学習を行う
  2. 1*200 cell で各cellには5*1 doubleのデータを持つ形で入力し学習を行う
  3. 1*1 cell で各cellには5*200 doubleのデータを持つ形で入力し学習を行う(最終的には3*1 cellにし3つの入力データを学習させたい)
の3パターンで学習を行い,その後同様に上記の方法で予測をしてみたところ,2が最も十分予測できている出力が得られ,しかし一番実現したい3の手法では一定値を出力し続けるような的はずれな出力結果となりました.(=追記=を参照)
> predict and updateの形で、進めていくとうまくいくかもしれませんが、それはもうお試しになりましたか?
2においてはpredict関数で十分な結果を得られたことから一旦predict関数で統一して予測を行っています.
また,predictAndUpdateState関数も用いて予測を行ったこともありますが,結果は改善されませんでした.
(predict関数の方がより所望の結果を出力していたような気がします)
> 何らかの方法を使って、データを共有していただくことは可能ですか?
こちらに関しましては情報等の観点から実際のデータをお渡しすることはかないませんので,以下に実際のものと似た傾向を示すデータを作成し,同様の設定で学習および予測を行うコードでしたらお渡しできます.
(26行目のswitch_dataを切り替えることにより各種実行可能です.)
パッと書いただけのものなのであまり良い結果もでませんし,3.のように局所解に陥ったような予測結果を出してしまっていますが,おおよそのイメージとしてはこのような形のデータになります.
以上より,入力の形状が違うとどうして学習結果も異なるのか,そして3.はなぜうまく学習および予測ができていないのかが現時点での疑問点になります.
= 追記 =
添付したコードの1*1cellを学習させたときの予測結果が失敗していることから気づいたのですが,
n*1cellを入力とし学習を行った場合,predictもしくはpredictAndUpdateStateによる予測はどのような入力を行えばよいのでしょうか?
ここで予測のための入力に5*200 doubleの1ステップ目という,1*1cellとは違う形状を用いていることにより正しく予測できていないのではと思いました.
Kenta
Kenta el 10 de Mayo de 2020
こんにちは、丁寧に教えていただきありがとうございます。回答のほうもご覧ください。コードもありがとうございます。
1)について、3通りの方法を詳しく教えていただきありがとうございます。私は、3.の方法が適していると思います。すでにご覧になっていると思いますが、こちら の例もすでにご覧になっているとは思いますが、例えば、100個のデータセットがあって、それぞれの実験で、19個のセンサー、280のタイムステップがあると、100個のセルのうちの1つに19*280のdouble型の変数が格納されています。
2)について、丁寧にありがとうございます。いただいたコードでは、2の方法でうまくいったので、こちらで解決できないかと願っています...
3)について、こちらもありがとうございます。1)の1.と3.はデータの組が1つのときは同じことを言っていると思います。こちらで両方試しましたが結果は同じになりました。
予測に関しては、添付いただいたコードの方では、t-1のデータ(例えばセンサーデータ5つ分)からtのデータを予測しています。例えば、データの形状がy=sinxのような形だと、t-1の値が0.5でも、そこでの微分値がないとtの値はわかりませんね。それと同じ理由で、LSTMのように時系列データ用の関数に読み込ませないと履歴を蓄積できないのでうまく予測できない、ということなのではないでしょうか?(たぶん、勘違いしていたらすいません...)
補足もありがとうございます。
実装は回答に貼り付けたのですが、これでいかがでしょうか?

Iniciar sesión para comentar.

Respuesta aceptada

Kenta
Kenta el 10 de Mayo de 2020
X = XTrain{1};
numTimeSteps = size(X,2);
for i = 1:numTimeSteps
v = X(:,i);
[net,score] = predictAndUpdateState(net,v);
scores(:,i) = score;
end
figure;
for i=1:5
subplot(2,3,i)
plot(1:99,x(i,:));hold on
plot(1:99,scores(i,:));hold off
title(sprintf('the variable to predict (%d)',i))
legend('xdata','predicted data')
end
こんにちは、コメントのほう、ありがとうございます。上のようにして予測していくとうまくフィッティングできますがいかがでしょうか。もちろん、XTrainで学習・予測をしているので、良い結果が出ても不思議はないのですが、ひとまず上のようなコードで実行してもうまくいかないでしょうか。
  3 comentarios
Kenta
Kenta el 11 de Mayo de 2020
こんにちは、うまくいったようで良かったです。手元のデータでもうまくフィッティングするといいですね!丁寧に状況説明していただき、こちらもわかりやすかったです。
データ以外で手を加えるとすると、
1)学習率を変える、2)オプティマイザ(adam,sgdmなど)を変える、3)層を深くする、4)L2正則化項を変える、などがあります。下のコードのコメントを見てください。
また、今回のLSTMではなく、biLSTM(双方向LSTM)というのもあります。
前から後ろの方向に加え、後ろの情報もあればわかりやすいとき(文章の判断など)は効果がありそうです。これも簡単に実行できます。
% Layers and options
numFeatures = 5;
numResponses = 5;
numHiddenUnits = 100; %精度が出ないとき、上げるといいかも
layers = [ ...
sequenceInputLayer(numFeatures)
flattenLayer('Name','flatten')
lstmLayer(numHiddenUnits,'OutputMode','sequence')
fullyConnectedLayer(numResponses)%精度が出ないとき、fully...を2つにしてもいいかも、下に例
% fullyConnectedLayer(20)
% fullyConnectedLayer(numResponses)
regressionLayer];
options = trainingOptions('sgdm', ...% adamにしたらいいかも
'MaxEpochs',200, ...% 学習がまだ進みそうだったら増やす
'GradientThreshold',1, ...% 重みが発散しないように抑える
'InitialLearnRate',0.005, ...% はじめは大きくして、だんだん小さくする
'LearnRateSchedule','piecewise', ...
'LearnRateDropPeriod',100, ...%学習が進むにつれて学習率を小さくする
%L2regulariztionなるものが設定できる(L2正則)
'LearnRateDropFactor',0.3, ...
'Verbose',0, ...
'Plots','training-progress');
WEBで探したところ、こちらにわかりやすい記事がありました。
とはいえ、学習曲線などを見ながら調整するものと思うので、やはりこれを手がかりにWEBなどで調べながら、また手元のデータで学習をすすめながら、ということが肝要に思います。
こちらの画像+LSTMのファイルも助けになるかもしれませんが、関係ないかもしれません...
Yuuki
Yuuki el 11 de Mayo de 2020
Kenta様,
具体例や関連記事までご紹介いただきありがとうございます.
Kenta様に教えていただいたことをもとに早速実際のデータで予測モデルを作成し,ハイパーパラメータやネットワーク構造を調節しながら改善していきたいと思います.
良い結果が得られた際にはまたどこかでご報告できればと思います.
改めましてご丁寧なご回答に心よりお礼申し上げます.
ありがとうございました.

Iniciar sesión para comentar.

Más respuestas (0)

Categorías

Más información sobre Statistics and Machine Learning Toolbox en Help Center y File Exchange.

Etiquetas

Productos


Versión

R2019b

Community Treasure Hunt

Find the treasures in MATLAB Central and discover how the community can help you!

Start Hunting!