How to fix the error: Error using trainNetwork, Input data indices must be nonnegative integers.

This is a sequence-to-sequence classification problem (e.g., input: (0.5, -5, 3, 10, 40, ...); prediction: (P, T, T, T, n/a, ...)).
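Each time step gets its own label. As a rough sketch of how per-sample labels like these can be built from labeled regions, the toy example below uses made-up region limits (not the QT data) and paints regions from lowest to highest priority, so QRS overrides an overlapping P wave. That is the intent of the `'prioritizeByList'` option used in `getmask`, but this base-MATLAB loop is only an illustration, not how `catmask` works internally.

```matlab
% Hypothetical toy data: 10 samples and three labeled regions given as
% [start end] sample limits. The QRS region overlaps the P region at
% samples 4-5.
numSamples = 10;
regionLimits = [1 5; 4 7; 8 9];
regionLabels = ["P"; "QRS"; "T"];
priority = ["QRS" "P" "T"];            % highest priority first

% Unlabeled samples keep the "n/a" label, as in getmask.
mask = repmat("n/a", 1, numSamples);

% Paint regions from lowest to highest priority so that higher-priority
% labels overwrite lower-priority ones where regions overlap.
for label = fliplr(priority)
    for r = find(regionLabels == label).'
        mask(regionLimits(r,1):regionLimits(r,2)) = label;
    end
end

% One categorical label per sample: P P P QRS QRS QRS QRS T T n/a
mask = categorical(mask);
```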
I built a Transformer encoder based on the code by Ben (MATLAB staff, https://www.mathworks.com/matlabcentral/answers/2014811-is-there-any-documentation-on-how-to-build-a-transformer-encoder-from-scratch-in-matlab ), replacing the LSTM layer with the Transformer encoder. My modified code is given at the bottom.
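For context on why the swap matters: the `wordEmbeddingLayer` in that encoder behaves like a lookup table, where each input value selects one column of a weight matrix. That works for integer tokens such as the 1, 2, ..., 10 values mentioned in the "vocabulary" comment of the code, but not for real-valued ECG amplitudes. The base-MATLAB sketch below is a simplified stand-in (an assumption, not the layer's actual implementation) contrasting a column lookup with the per-time-step linear projection that a `fullyConnectedLayer` computes.

```matlab
rng(0);                                 % reproducible random weights
embeddingDim = 20;
vocabSize = 10;                         % toy vocabulary: tokens 1..10

% Lookup table (simplified stand-in): column k holds the embedding of token k.
W = randn(embeddingDim, vocabSize);
tokens = [3 1 10 7];
tokenEmbeddings = W(:, tokens);         % 20-by-4, one column per time step

% Real-valued ECG samples are not valid column indices.
ecgSamples = [0.5 -5 3 10 40];
try
    W(:, ecgSamples);
    lookupFailed = false;
catch
    lookupFailed = true;                % MATLAB rejects non-integer and negative indices
end

% A per-time-step linear projection (what a fullyConnectedLayer computes)
% accepts any real value.
Wproj = randn(embeddingDim, 1);
b = zeros(embeddingDim, 1);
projected = Wproj * ecgSamples + b;     % 20-by-5, one column per time step
```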
When I run the network-training section, I get the following error. Any help fixing it would be appreciated.
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Error in waveExtractionTest_TransEnc (line ...)
filteredNet = trainNetwork(filteredTrainSignalss,trainLabels,net,options);
Caused by:
Error using nnet.internal.cnn.layer.util.EmbeddingDAGNetworkBaseStrategy/embedData
Input data indices must be nonnegative integers.
--------------------------------------------------------------------------------------------------------------------------------------------------------------------
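A likely cause is the `wordEmbeddingLayer`: it expects integer token indices, while the filtered ECG segments contain real-valued amplitudes, which matches the "indices must be nonnegative integers" message. One possible change is to project each sample with a `fullyConnectedLayer` and to size `positionEmbeddingLayer` to the segment length instead of to `vocabSize`. The sketch below assumes one encoder block instead of two, `maxSequenceLength = 5000` to match the segments produced by `resizeData`, and four output classes (P, QRS, T, n/a); it has not been trained or validated on the QT data.

```matlab
% Hedged sketch: continuous-valued input front end for the encoder.
% Assumption: each observation is a 1-by-5000 real-valued ECG segment.
numHeads = 1;
numKeyChannels = 20;
feedforwardHiddenSize = 100;
modelHiddenSize = 20;
maxSequenceLength = 5000;   % segment length produced by resizeData
numClasses = 4;             % P, QRS, T, n/a

layers = [
    sequenceInputLayer(1,Name="in")
    fullyConnectedLayer(modelHiddenSize,Name="embedding")   % linear projection of each real-valued sample
    positionEmbeddingLayer(modelHiddenSize,maxSequenceLength,Name="pos_embedding")
    additionLayer(2,Name="embed_add")                       % data + position embeddings
    selfAttentionLayer(numHeads,numKeyChannels,Name="attention")
    additionLayer(2,Name="attention_add")
    layerNormalizationLayer(Name="attention_norm")
    fullyConnectedLayer(feedforwardHiddenSize,Name="ff1")
    reluLayer(Name="relu")
    fullyConnectedLayer(modelHiddenSize,Name="ff2")
    additionLayer(2,Name="feedforward_add")
    layerNormalizationLayer(Name="encoder1_out")
    fullyConnectedLayer(numClasses,Name="fc_out")           % one score per class
    softmaxLayer(Name="softmax")
    classificationLayer(Name="classification")];

% Residual connections, wired as in the posted code.
lgraph = layerGraph(layers);
lgraph = connectLayers(lgraph,"embedding","embed_add/in2");
lgraph = connectLayers(lgraph,"embed_add","attention_add/in2");
lgraph = connectLayers(lgraph,"attention_norm","feedforward_add/in2");
```

The second encoder block and the remaining connections would follow the posted code unchanged.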
%% Download the data
dataURL = 'https://www.mathworks.com/supportfiles/SPT/data/QTDatabaseECGData1.zip';
dirQT = pwd;
datasetFolder = fullfile(dirQT,'QTDataset');
zipFile = fullfile(dirQT,'QTDatabaseECGData.zip');
if ~exist(datasetFolder,'dir')
    websave(zipFile,dataURL);
    unzip(zipFile,dirQT);
end
%%
sds = signalDatastore(datasetFolder,'SignalVariableNames',["ecgSignal","signalRegionLabels"])
% sds =
%   signalDatastore with properties:
%     Files: {'/users/mss.system.g87Cwy/QTDataset/ecg1.mat'; '/users/mss.system.g87Cwy/QTDataset/ecg10.mat'; '/users/mss.system.g87Cwy/QTDataset/ecg100.mat' ... and 207 more}
%     Folders: {'/users/mss.system.g87Cwy/QTDataset'}
%     AlternateFileSystemRoots: [0×0 string]
%     ReadSize: 1
%     SignalVariableNames: ["ecgSignal" "signalRegionLabels"]
%     ReadOutputOrientation: "column"
%%
rng default
[trainIdx,~,testIdx] = dividerand(numel(sds.Files),0.8,0,0.2);
trainDs = subset(sds,trainIdx);
testDs = subset(sds,testIdx);
%%
trainDs = transform(trainDs, @getmask);
testDs = transform(testDs, @getmask);
%%
trainDs = transform(trainDs,@resizeData);
testDs = transform(testDs,@resizeData);
%%
% Bandpass filter design
hFilt = designfilt('bandpassiir', 'StopbandFrequency1',0.4215,'PassbandFrequency1', 0.5, ...
    'PassbandFrequency2',40,'StopbandFrequency2',53.345,...
    'StopbandAttenuation1',60,'PassbandRipple',0.1,'StopbandAttenuation2',60,...
    'SampleRate',250,'DesignMethod','ellip');
% Create tall arrays from the transformed datastores and filter the signals
tallTrainSet = tall(trainDs);
% Output:
% Starting parallel pool (parpool) using the 'Processes' profile ...
% Parallel pool using the 'Processes' profile is shutting down.
% Error using tall
% Could not create a mapreduce execution environment from the default parallel cluster.
%
% Caused by:
%     Error using gcp
%     Parallel pool failed to start with the following error. For more detailed information, validate the profile 'Processes' in the Cluster Profile Manager.
%         Error using parallel.internal.pool.AbstractInteractiveClient>iThrowWithCause
%         Failed to initialize the interactive session.
%             Error using parallel.internal.pool.AbstractInteractiveClient>iThrowIfBadParallelJobStatus
%             The interactive communicating job failed with no message.
tallTestSet = tall(testDs);
filteredTrainSignals = gather(cellfun(@(x)filter(hFilt,x),tallTrainSet(:,1),'UniformOutput',false));
trainLabels = gather(tallTrainSet(:,2));
filteredTestSignals = gather(cellfun(@(x)filter(hFilt,x),tallTestSet(:,1),'UniformOutput',false));
testLabels = gather(tallTestSet(:,2));
%% Create model
% We will use 2 encoder layers.
numHeads = 1;
numKeyChannels = 20;
feedforwardHiddenSize = 100;
modelHiddenSize = 20;
% Since the values in the sequence can be 1,2, ..., 10 the "vocabulary" size is 10.
vocabSize = 100000; % each training sequence (one observation) is 5000 samples long
inputSize = 1;
encoderLayers = [
    sequenceInputLayer(1,Name="in")                                 % input
    wordEmbeddingLayer(modelHiddenSize,vocabSize,Name="embedding")  % embedding
    positionEmbeddingLayer(modelHiddenSize,vocabSize)               % position embedding
    additionLayer(2,Name="embed_add")                               % add the data and position embeddings
    selfAttentionLayer(numHeads,numKeyChannels)                     % encoder block 1
    additionLayer(2,Name="attention_add")
    layerNormalizationLayer(Name="attention_norm")
    fullyConnectedLayer(feedforwardHiddenSize)
    reluLayer
    fullyConnectedLayer(modelHiddenSize)
    additionLayer(2,Name="feedforward_add")
    layerNormalizationLayer(Name="encoder1_out")
    selfAttentionLayer(numHeads,numKeyChannels)                     % encoder block 2
    additionLayer(2,Name="attention2_add")
    layerNormalizationLayer(Name="attention2_norm")
    fullyConnectedLayer(feedforwardHiddenSize)
    reluLayer
    fullyConnectedLayer(modelHiddenSize)
    additionLayer(2,Name="feedforward2_add")
    layerNormalizationLayer()
    % indexing1dLayer
    % fullyConnectedLayer(inputSize)
    fullyConnectedLayer(4)                                          % output head: one score per class (P, QRS, T, n/a)
    softmaxLayer("Name","softmax")
    classificationLayer("Name","classification")
    ];
%
net = layerGraph(encoderLayers);
net = connectLayers(net,"embed_add","attention_add/in2");
net = connectLayers(net,"embedding","embed_add/in2");
net = connectLayers(net,"attention_norm","feedforward_add/in2");
net = connectLayers(net,"encoder1_out","attention2_add/in2");
net = connectLayers(net,"attention2_norm","feedforward2_add/in2");
% net = initialize(net);
% analyze the network to see how data flows through it
analyzeNetwork(net)
%
%%
options = trainingOptions("adam", ...
    MaxEpochs = 10, ...
    MiniBatchSize = 50, ...
    Plots="training-progress", ...
    Shuffle="every-epoch", ...
    InitialLearnRate=1e-2, ...
    LearnRateDropFactor=0.9, ...
    LearnRateDropPeriod=3, ...
    LearnRateSchedule="piecewise");
%%
filteredNet = trainNetwork(filteredTrainSignals,trainLabels,net,options);
%
%
%
%
%% Local functions required by the transforms above: getmask and resizeData
function outputCell = getmask(inputCell)
%GETMASK Convert region labels to a mask of labels of size equal to the
%size of the input ECG signal.
%
%   inputCell is a two-element cell array containing an ECG signal vector
%   and a table of region labels.
%
%   outputCell is a two-element cell array containing the ECG signal vector
%   and a categorical label vector mask of the same length as the signal.
%   Copyright 2020 The MathWorks, Inc.
    sig = inputCell{1};
    roiTable = inputCell{2};
    L = length(sig);
    M = signalMask(roiTable);
    % Get categorical mask and give priority to QRS regions when there is overlap
    mask = catmask(M,L,'OverlapAction','prioritizeByList','PriorityList',[2 1 3]);
    % Set missing values to "n/a"
    mask(ismissing(mask)) = "n/a";
    outputCell = {sig,mask};
end
%
%
%
%
function outputCell = resizeData(inputCell)
%RESIZEDATA Break input ECG signal and label mask into segments of length
%5000.
%
%   inputCell is a two-element cell array containing an ECG signal and a
%   label mask.
%
%   outputCell is a two-column cell array containing as many 5000-long
%   signal segments and label masks that were possible to generate from the
%   input data.
%   Copyright 2019 The MathWorks, Inc.
    targetLength = 5000;
    sig = inputCell{1};
    mask = inputCell{2};
    % Get number of chunks
    numChunks = floor(size(sig,1)/targetLength);
    % Truncate signal and mask to integer number of chunks
    sig = sig(1:numChunks*targetLength);
    mask = mask(1:numChunks*targetLength);
    % Create a cell array containing signal chunks
    sigOut = reshape(sig,targetLength,numChunks)';
    sigOut = num2cell(sigOut,2);
    % Create a cell array containing mask chunks
    lblOut = reshape(mask,targetLength,numChunks)';
    lblOut = num2cell(lblOut,2);
    % Output a two-column cell array with all chunks
    outputCell = [sigOut, lblOut];
end
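To see what `resizeData` produces, the same chunking steps can be run on a toy signal. This sketch assumes a target length of 5 instead of 5000 and a made-up 12-sample signal with a label column; the last two samples are discarded because they do not fill a whole chunk.

```matlab
% Toy version of the resizeData steps (assumption: targetLength = 5
% instead of 5000, and a made-up 12-sample signal with labels).
targetLength = 5;
sig = (1:12)';                                            % column signal
mask = categorical([repmat("P",4,1); repmat("QRS",4,1); repmat("T",4,1)]);

numChunks = floor(size(sig,1)/targetLength);              % 2 whole chunks
sig = sig(1:numChunks*targetLength);                      % samples 11-12 dropped
mask = mask(1:numChunks*targetLength);

% After reshape and transpose, each row is one chunk; num2cell(...,2)
% turns every row into its own cell.
sigOut = num2cell(reshape(sig,targetLength,numChunks)', 2);
lblOut = num2cell(reshape(mask,targetLength,numChunks)', 2);
outputCell = [sigOut, lblOut];                            % 2-by-2 cell: {signal chunk, label chunk}
```

Each row of `outputCell` is one observation in the form `trainNetwork` receives for sequence-to-sequence classification: a 1-by-T signal and a 1-by-T categorical label sequence.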
Walter Roberson on 13 Dec 2023
It looks like the signals are completely the wrong size for the network.
At some point it tries to shape one column of a 20-by-100001 array to be 20-by-50-by-5000.
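The shapes in this comment can be reproduced with a little arithmetic. The sketch below assumes the posted settings (`modelHiddenSize = 20`, `vocabSize = 100000`, `MiniBatchSize = 50`, 5000-sample segments), one extra embedding column (consistent with the 20 by 100001 shape reported above), single-precision data, and standard dot-product attention that stores one T-by-T score matrix per observation and head. Under those assumptions the attention scores alone would take roughly 4.66 GiB per mini-batch, so shorter segments or a smaller mini-batch may also be needed.

```matlab
% Back-of-the-envelope sizes (assumptions listed above).
modelHiddenSize = 20;
vocabSize = 100000;
miniBatchSize = 50;
seqLength = 5000;

% Embedding table: one column per vocabulary entry plus one extra column.
embeddingTableSize = [modelHiddenSize, vocabSize + 1];         % 20-by-100001

% Embedded mini-batch in channel-by-batch-by-time layout.
activationSize = [modelHiddenSize, miniBatchSize, seqLength];  % 20-by-50-by-5000
activationElements = prod(activationSize);                     % 5,000,000

% Self-attention scores: one T-by-T matrix per observation and per head.
numHeads = 1;
scoreElements = seqLength^2 * miniBatchSize * numHeads;        % 1.25e9
scoreGiB = scoreElements * 4 / 2^30;                           % 4 bytes per single
```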

Answers (0)

Categories: AI for Signals

Release: R2023b
