
Train PyTorch Channel Prediction Models with Online Training

Since R2026a

This example shows how to train a PyTorch® gated recurrent unit (GRU) channel prediction network online by generating each training batch in MATLAB® on the fly, enabling real-time adaptation to time-varying wireless channels.

Introduction

Wireless channel prediction plays a crucial role in modern communications systems by enabling more efficient and reliable data transmission. If historical data is available, neural networks can learn channel characteristics directly by using data-driven techniques. Training from historical channel data is particularly beneficial in complex interference environments that are difficult to accurately model parametrically.

Recurrent neural networks, such as long short-term memory (LSTM) networks and GRUs, are excellent candidates for forecasting future channel states based on past measurements. In this online training example, a GRU-based channel prediction model defined in PyTorch is trained with fresh data that is generated in MATLAB within the training loop. This workflow contrasts with offline training, where training and validation data are pre-generated; here, fresh data is generated in each iteration. For details on the offline training workflow, see the Train PyTorch Channel Prediction Models example.

In this example, you train a GRU network defined in PyTorch. The nr_channel_predictor.py file contains the neural network definition, training routines, and other functionality. The nr_channel_predictor_wrapper.py file contains the interface functions. For guidance on implementing similar wrapper functions, see PyTorch Wrapper Template.

You generate data in MATLAB as described in the Preprocess Data for AI-Based CSI Prediction example and train the network in PyTorch using the Python® interface in MATLAB.

Set Up Python Environment

Before running this example, set up the Python environment as explained in Call Python from MATLAB for Wireless. Specify the full path of the Python executable to use in the pythonPath field below. The helperSetupPyenv function sets the Python environment in MATLAB according to the selected options and checks that the libraries listed in the requirements_chanpre.txt file are installed. This example is tested with Python version 3.11.

if ispc
  pythonPath = ".\.venv\Scripts\pythonw.exe";
else
  pythonPath = "./venv/bin/python3";
end
requirementsFile = "requirements_chanpre.txt";
executionMode = "OutOfProcess";
currentPyenv = helperSetupPyenv(pythonPath,executionMode,requirementsFile);
Setting up Python environment
Parsing requirements_chanpre.txt 
Checking required package 'numpy'
Checking required package 'torch'
Required Python libraries are installed.

You can use the following process ID and name to attach a debugger to the Python interface and debug the example code.

fprintf("Process ID for '%s' is %s.\n",...
  currentPyenv.ProcessName,currentPyenv.ProcessID)
Process ID for 'MATLABPyHost' is 85136.

Preload the Python module to reduce start time.

module = py.importlib.import_module('nr_channel_predictor_wrapper');

Initialize Data Generation

The Preprocess Data for AI-Based CSI Prediction example shows data generation details for channel prediction. This example uses the following parameters for system and channel configuration.

txAntennaSize = 2;
rxAntennaSize = 2;
rmsDelaySpread = 300e-9;     % s
maxDoppler = 37;             % Hz
nSizeGrid = 52;              % Number of resource blocks (RB)
subcarrierSpacing = 15;      % 15, 30, 60, or 120 kHz
numerology = log2(subcarrierSpacing/15);

channel = nrTDLChannel;
channel.DelayProfile = 'TDL-A';
channel.DelaySpread = rmsDelaySpread;       % s
channel.MaximumDopplerShift = maxDoppler;   % Hz
channel.RandomStream = "Global stream";
channel.NumTransmitAntennas = txAntennaSize;
channel.NumReceiveAntennas = rxAntennaSize;
channel.ChannelFiltering = false;

% Carrier definition
carrier = nrCarrierConfig;
carrier.NSizeGrid = nSizeGrid;
carrier.SubcarrierSpacing = subcarrierSpacing;

% Channel prediction horizon
horizon = 2;
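In 5G NR, the numerology μ relates to the subcarrier spacing as SCS = 15 · 2^μ kHz, so μ = log2(SCS/15). A quick sanity check of this mapping, sketched in Python (the variable names are illustrative):

```python
from math import log2

# 3GPP TS 38.211 numerology: SCS = 15 * 2^mu kHz, so mu = log2(SCS / 15)
for scs_khz, expected_mu in [(15, 0), (30, 1), (60, 2), (120, 3)]:
    mu = int(log2(scs_khz / 15))
    assert mu == expected_mu, (scs_khz, mu)
print("numerology mapping checks out")
```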

In this type of online training, channel data is generated on demand rather than pre-generated. The helperChannelBatchGenerator function uses a helperBackgroundRunner System object™ to invoke helperChanPreBatchData asynchronously. This setup avoids pausing the training loop to wait for new samples. For more information on helperBackgroundRunner, see Background Data Generation.

Each call to helperChanPreBatchData produces one time slot of channel estimates. That is, you get estimates for every subcarrier and receive antenna (an Nsubcarriers-by-Nrx complex array), formatted into Ntx-by-Nsequence feature sequences. The complex-valued channel estimates are stored as interleaved real-imaginary samples. The batch generator buffers these time-slot samples in a FIFO queue, shuffles them in the batch dimension, and on each call returns exactly BatchSize samples for training or validation. Because you can change dataGen.BatchSize at any time, you can seamlessly switch between smaller "online" batches for training and larger batches for periodic validation.
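The interleaved real-imaginary storage can be sketched as follows. This is an illustrative NumPy sketch, not the actual implementation of helperChanPreBatchData; the array shapes and the helper name are assumptions (624 subcarriers corresponds to 52 resource blocks of 12 subcarriers each):

```python
import numpy as np

def interleave_complex(h):
    """Interleave the real and imaginary parts of a complex array
    along its last axis, doubling the feature dimension."""
    out = np.empty(h.shape + (2,), dtype=np.float64)
    out[..., 0] = h.real   # real part of each sample
    out[..., 1] = h.imag   # imaginary part of each sample
    return out.reshape(h.shape[:-1] + (2 * h.shape[-1],))

# One slot of channel estimates: Nsubcarriers x Nantennas complex samples
rng = np.random.default_rng(0)
h = rng.standard_normal((624, 2)) + 1j * rng.standard_normal((624, 2))
x = interleave_complex(h)
print(x.shape)  # (624, 4): [Re(a0), Im(a0), Re(a1), Im(a1)] per subcarrier
```

The resulting feature width of 4 matches the input size of the GRU defined later in this example.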

maxBatchSize = 10000;   % Maximum number of samples in a batch
dataMin = -2.5;         % Minimum value of generated data for normalization
dataMax = 2.5;          % Maximum value of generated data for normalization
sequenceLength = 55;    % Continuous data generation length in slots
SNR_dB = 20;

dataGen = helperChannelBatchGenerator( ...
  Channel=channel, ...
  Carrier=carrier, ...
  MaxBatchSize=maxBatchSize, ...
  SequenceLength=sequenceLength, ...
  DataMin=dataMin, ...
  DataMax=dataMax, ...
  SNRdB=SNR_dB, ...
  Horizon=horizon);
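The dataMin and dataMax values define a fixed min-max normalization of the generated data. A minimal sketch of such a normalization, assuming a target range of [-1, 1] (the actual target range used by helperChannelBatchGenerator is an assumption here):

```python
import numpy as np

DATA_MIN, DATA_MAX = -2.5, 2.5  # from the MATLAB configuration above

def normalize(x, lo=DATA_MIN, hi=DATA_MAX):
    """Map values in [lo, hi] to [-1, 1]; clip outliers first."""
    x = np.clip(x, lo, hi)
    return 2.0 * (x - lo) / (hi - lo) - 1.0

def denormalize(y, lo=DATA_MIN, hi=DATA_MAX):
    """Invert the mapping to recover channel-domain values."""
    return (y + 1.0) / 2.0 * (hi - lo) + lo

x = np.array([-2.5, 0.0, 2.5])
print(normalize(x))  # [-1.  0.  1.]
```

Using fixed normalization bounds, rather than per-batch statistics, keeps the mapping consistent across training and inference.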

Initialize Neural Network

Initialize the channel predictor neural network. Set the GRU hidden size to 128 and the number of GRU layers to 2. The chanPredictor variable is the PyTorch model for the GRU-based channel predictor.

gruHiddenSize = 128;
gruNumLayers  = 2;
Ntx = txAntennaSize;
chanPredictor = py.nr_channel_predictor_wrapper.construct_model(...
  Ntx, ...
  gruHiddenSize, ...
  gruNumLayers);

py.nr_channel_predictor_wrapper.info(chanPredictor)
Model architecture:
ChannelPredictorGRU(
  (gru): GRU(4, 128, num_layers=2, batch_first=True, dropout=0.3)
  (layer_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (fc): Linear(in_features=128, out_features=4, bias=True)
)

Total number of parameters: 151300
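The printed parameter count follows from the layer shapes: each GRU layer has three gates (reset, update, new), each with input-hidden and hidden-hidden weight matrices plus two bias vectors. A quick check in plain Python:

```python
# Parameter count for GRU(4, 128, num_layers=2) + LayerNorm(128) + Linear(128, 4)
n_in, n_hidden, n_out = 4, 128, 4

def gru_layer_params(input_size, hidden_size):
    # 3 gates: W_ih (3H x In), W_hh (3H x H), b_ih (3H), b_hh (3H)
    return 3 * hidden_size * (input_size + hidden_size) + 2 * 3 * hidden_size

total = (gru_layer_params(n_in, n_hidden)        # layer 1: 51,456
         + gru_layer_params(n_hidden, n_hidden)  # layer 2: 99,072
         + 2 * n_hidden                          # LayerNorm weight + bias: 256
         + n_hidden * n_out + n_out)             # Linear: 516
print(total)  # 151300
```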

Train Neural Network

The nr_channel_predictor_wrapper.py file contains the MATLAB interface functions to train the channel predictor neural network. Set values for the hyperparameters: number of iterations, batch size, initial learning rate, and validation frequency in iterations.

In the online training loop, each iteration generates a fresh data batch and then performs one weight update. The loop also performs periodic validation.

The train_one_iteration Python function returns the training loss for the current iteration. The trainer object also tracks the current learning rate, which you can query with the get_current_learning_rate function. Accessing these values triggers a data transfer from Python to MATLAB. To minimize the training time while still tracking training performance, query the learning rate only during validation.

Set the progressPlot variable to true to plot the training progress using the trainingProgressMonitor function, which requires Deep Learning Toolbox®. To print the metrics, set the verbose variable to true.

Training and testing for 50,000 iterations takes about 60 minutes on a PC with an NVIDIA® TITAN V GPU with a compute capability of 7.0 and 12 GB of memory. To train the network, set trainNow to true.

trainNow = false;
if trainNow
  dltAvailable = checkDeepLearningToolbox();
  numIterations            = 50000;
  trainingBatchSize        = 512;
  validationBatchSize      = 1000;
  initialLearningRate      = 1e-3;
  learningRateUpdatePeriod = 175;
  validationFrequency      = 1000;
  progressPlot             = false;    % Set to true if you want progress plot
  verbose                  = true;

  if dltAvailable && progressPlot
    monitor = trainingProgressMonitor( ...
      Metrics=["TrainingLoss","ValidationLoss"], ...
      Info=["LearningRate","Iteration"], ...
      XLabel="Iteration");
    groupSubPlot(monitor,"Loss",["TrainingLoss","ValidationLoss"]);
    yscale(monitor,"Loss","log")
    monitor.Status = "Running";
  end

  tStartLoop = tic;
  dataGen.BatchSize = trainingBatchSize;

  % Set up the trainer for online training using the Python wrapper.
  trainer = py.nr_channel_predictor_wrapper.setup_trainer( ...
    Ntx, ...                 % number of transmit antennas
    gruHiddenSize, ...       % hidden size
    gruNumLayers, ...        % number of layers
    initialLearningRate, ... % learning rate
    trainingBatchSize, ...   % batch size
    verbose);

  % Training loop
  numOutputs = floor(numIterations/validationFrequency);
  trainingLoss = nan(1,numIterations);
  validationLoss = nan(1,numOutputs);
  iterations = nan(1,numOutputs);
  outCount = 1;
  tStart = tic;
  fprintf('Starting online training...\n');
  for iteration = 1:numIterations
    % Get the next batch from the data generator
    [inputData,targetData] = dataGen();

    % Perform one training update
    trainLoss = py.nr_channel_predictor_wrapper.train_one_iteration(trainer,inputData,targetData, ...
      ~mod(iteration,learningRateUpdatePeriod));
    trainingLoss(iteration) = trainLoss;

    % Every validationFrequency iterations, perform validation
    if mod(iteration,validationFrequency) == 0
      % Access data from Python
      learningRate = trainer.get_current_learning_rate();

      % Generate validation batch
      dataGen.BatchSize = validationBatchSize;
      [inputData,targetData] = dataGen();
      dataGen.BatchSize = trainingBatchSize;

      % Perform validation
      valLoss = py.nr_channel_predictor_wrapper.validate(trainer,inputData,targetData);

      % Report metrics
      valLoss = double(valLoss);
      if verbose
        et = seconds(toc(tStartLoop)); et.Format = "hh:mm:ss.SSS";
        fprintf('%s : [%d/%d] Training Loss: %1.3f, Validation Loss: %1.3f, Learning Rate: %1.3e\n',...
          et,iteration,numIterations,trainLoss,valLoss,learningRate);
      end
      if dltAvailable && progressPlot
        recordMetrics(monitor, ...
          iteration, ...
          TrainingLoss=trainLoss, ...
          ValidationLoss=valLoss);
        updateInfo(monitor, ...
          LearningRate=learningRate, ...
          Iteration=sprintf("%d of %d",iteration,numIterations));
        monitor.Progress = 100*iteration/numIterations;
        if monitor.Stop
          monitor.Status = "Stopped by user";
          break
        end
      else
        validationLoss(outCount) = valLoss;
        iterations(outCount) = iteration;
        outCount = outCount+1;
      end
    end
  end
  if dltAvailable && progressPlot
    monitor.Status = "Training complete";
    trainingLoss = monitor.MetricData.TrainingLoss(:,2);
    validationLoss = monitor.MetricData.ValidationLoss(:,2);
    iterations = monitor.MetricData.ValidationLoss(:,1);
  end
  et = toc(tStart); et = seconds(et); et.Format = "hh:mm:ss.SSS";
  fprintf('%s : Training complete.\n',et);

Release dataGen to stop background data generation.

  release(dataGen)

Get the trained network and final validation loss.

  chanPredictor = trainer.model;
  finalValidationLoss = validationLoss(end);

Save the network for future use together with the training information.

  modelFileName = sprintf("channel_predictor_gru_horizon%d_iters%d",horizon,numIterations);
  fileName = py.nr_channel_predictor_wrapper.save( ...
    chanPredictor, ...
    modelFileName, ...
    Ntx, ...
    gruHiddenSize, ...
    gruNumLayers, ...
    initialLearningRate, ...
    trainingBatchSize, ...
    numIterations, ...
    validationFrequency);
  infoFileName = modelFileName+"_info";
  save(infoFileName,"dataMax","dataMin","trainingLoss","validationLoss", ...
    "et","initialLearningRate","trainingBatchSize","numIterations","validationFrequency", ...
    "Ntx","gruHiddenSize","gruNumLayers");
  fprintf("Saved network in '%s' file and\nnetwork info in '%s.mat' file.\n", ...
    string(fileName),infoFileName)
else
  load("channel_predictor_gru_horizon2_iters50000_info.mat","dataMax","dataMin","trainingLoss","validationLoss",...
    "et","initialLearningRate","trainingBatchSize","numIterations","validationFrequency",...
    "Ntx","gruHiddenSize","gruNumLayers");
  finalValidationLoss = validationLoss(end);
  iterations = validationFrequency:validationFrequency:validationFrequency*length(validationLoss);
  chanPredictor = py.nr_channel_predictor_wrapper.construct_model(...
    Ntx, ...
    gruHiddenSize, ...
    gruNumLayers, ...
    "channel_predictor_gru_horizon2_iters50000.pth")
end
chanPredictor = 
  Python ChannelPredictorGRU with properties:

    training: 1

    ChannelPredictorGRU(
      (gru): GRU(4, 128, num_layers=2, batch_first=True, dropout=0.3)
      (layer_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (fc): Linear(in_features=128, out_features=4, bias=True)
    )

fprintf("Validation Loss: %f dB\n",10*log10(finalValidationLoss))
Validation Loss: -29.457658 dB

Plot the training and validation loss. As the number of iterations increases, the loss value converges to about -30 dB.
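The loss plotted here in dB is the normalized mean squared error (NMSE). A minimal sketch of the metric, assuming the standard definition NMSE = E[|H − Ĥ|²] / E[|H|²]:

```python
import numpy as np

def nmse_db(h_true, h_pred):
    """NMSE between true and predicted channels, in dB."""
    err = np.sum(np.abs(h_true - h_pred) ** 2)
    ref = np.sum(np.abs(h_true) ** 2)
    return 10.0 * np.log10(err / ref)

# Synthetic check: a prediction error with 1% of the channel power
rng = np.random.default_rng(1)
h = rng.standard_normal(1000) + 1j * rng.standard_normal(1000)
noise = 0.01 * (rng.standard_normal(1000) + 1j * rng.standard_normal(1000))
print(nmse_db(h, h + noise))  # roughly -40 dB
```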

figure
plot(1:numIterations,10*log10(trainingLoss));
hold on
plot(iterations,10*log10(validationLoss));
hold off
legend("Training","Validation")
xlabel("Iteration")
ylabel("Loss (dB)")
title("Training Performance (NMSE as Loss)")
grid on

Figure contains an axes object. The axes object with title Training Performance (NMSE as Loss), xlabel Iteration, ylabel Loss (dB) contains 2 objects of type line. These objects represent Training, Validation.

References

[1] W. Jiang and H. D. Schotten, "Recurrent Neural Network-Based Frequency-Domain Channel Prediction for Wideband Communications," 2019 IEEE 89th Vehicular Technology Conference (VTC2019-Spring), Kuala Lumpur, Malaysia, 2019, pp. 1–6, doi: 10.1109/VTCSpring.2019.8746352.

[2] O. Stenhammar, G. Fodor, and C. Fischione, "A Comparison of Neural Networks for Wireless Channel Prediction," in IEEE Wireless Communications, vol. 31, no. 3, pp. 235–241, June 2024, doi: 10.1109/MWC.006.2300140.
