Initializing LSTM which is imported using ONNX

Andreas on 17 Jul 2024
Answered: Andreas on 23 Jul 2024
Hi,
I am training an LSTM for reinforcement learning with Ray in Python. I would like to export the model to ONNX and then import it into MATLAB. As far as I understand, I need to initialize the network in MATLAB after importing it. However, I cannot work out the input shapes and formats that MATLAB expects.
Minimum working example:
Python code to train LSTM:
import numpy as np
import torch
from ray.rllib.algorithms.ppo import PPOConfig

# Configure the algorithm
algo = (
    PPOConfig()
    .env_runners(num_env_runners=1)
    .resources(num_gpus=0)
    .environment(env="CartPole-v1")
    .training(model={"use_lstm": True})
    .build()
)

# Train for 2 iterations
for i in range(2):
    result = algo.train()

# Get the policy and its underlying model
ppo_policy = algo.get_policy()
model = ppo_policy.model

# Batch size
B = 1

# Build example LSTM inputs: observation, two state tensors, and sequence lengths
input_dict = {"obs": torch.tensor(np.random.uniform(0, 1.0, size=(B, 4)).astype(np.float32))}
state_batches = [torch.zeros((B, 256), dtype=torch.float32), torch.zeros((B, 256), dtype=torch.float32)]
seq_lens = torch.ones([B], dtype=int)

# Apply the model to the example inputs
print(model(input_dict, state=state_batches, seq_lens=seq_lens))

# Export the model to ONNX (opset 14)
ppo_policy.export_model('onnx14', onnx=14)
Code in Matlab:
% Import model from where I saved it
net = importNetworkFromONNX('path/to/onnx-model');
% input shapes
obs_size = [1,4];
state_size=[2,1,256];
seq_lens_size=[1];
% initialize input arrays
obs = dlarray(rand(obs_size),"BS");
state = dlarray(rand(state_size),"SBS");
seq_len = dlarray(rand(seq_lens_size),"SB");
% initialize net
net = initialize(net,obs,state,seq_len);
Error message:
I appreciate any help!
Best,
Andreas
  2 comments
Nilesh on 17 Jul 2024
Edited: Nilesh on 17 Jul 2024
Hello Andreas,
Have you tried asking ChatGPT about this issue?
Andreas on 17 Jul 2024
Yes, but without success so far.


Answers (3)

Joss Knight on 18 Jul 2024
This code is suspect:
% initialize input arrays
obs = dlarray(rand(obs_size),"BS");
state = dlarray(rand(state_size),"SBS");
seq_len = dlarray(rand(seq_lens_size),"SB");
% initialize net
net = initialize(net,obs,state,seq_len);
I think your network has a single input, so you need to pass a single input to initialize (along with the network): some example input exactly like the one you want to pass to predict. I think you have two channels and a sequence length of 256? One of your dimensions is time, so you need a T dimension. And I don't think you have any spatial dimensions, so no S labels. So you need something like
exampleInput = dlarray(rand(2,1,256),'CBT');
net = initialize(net, exampleInput);
Or if you prefer, a permutation of that like
exampleInput = dlarray(rand(256,2,1),'TCB');
net = initialize(net, exampleInput);
If this doesn't work, try running analyzeNetwork(net) to see where your inputs are and we can work out what to expect.
  1 comment
Andreas on 23 Jul 2024
Hi,
The network does not have a single input. I managed to solve the issue; see my response below. Thank you for your help anyway!



Kaustab Pal on 19 Jul 2024
It seems you want to determine the input dimension of your imported network. You can easily find this information using the analyzeNetwork function. This function provides an interactive visualization of the network architecture and detailed information, including:
  • Layer types
  • Sizes and formats of layer learnable parameters
  • States and activations
  • Total number of learnable parameters
The activation size of the topmost layer will give you the input dimension.
Additionally, when creating dlarray objects in MATLAB, you need to specify the format, which must follow this order:
  • "S" (Spatial)
  • "C" (Channel)
  • "B" (Batch)
  • "T" (Time)
  • "U" (Unspecified)
For more details, you can refer to the following links:
  1. analyzeNetwork Documentation: https://www.mathworks.com/help/deeplearning/ref/analyzenetwork.html#mw_bdd24886-fa03-4540-a111-391541a0a684
  2. dlarray Documentation: https://www.mathworks.com/help/deeplearning/ref/dlarray.html#d126e57736:~:text=When%20you%20create%20a%20formatted%20dlarray%20object%2C%20the%20software%20automatically%20permutes%20the%20dimensions%20such%20that%20the%20format%20has%20dimensions%20in%20this%20order%3A
Hope this helps.
  1 comment
Joss Knight on 19 Jul 2024
Just FYI, the formats do not have to follow that order.
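To illustrate the point about format order: as the dlarray documentation linked above describes it, the software permutes the dimensions of a formatted dlarray into the order S, C, B, T, U on its own, so the order you type the labels in is not fixed. The plain-Python sketch below mimics that reordering for a format string and a size. The helper name `canonicalize` is hypothetical and is not part of MATLAB or any library; it only models the bookkeeping, not dlarray itself.

```python
# Order that a formatted dlarray ends up in, per the dlarray documentation.
CANONICAL_ORDER = "SCBTU"


def canonicalize(fmt: str, size: tuple) -> tuple:
    """Reorder a (format, size) pair into S, C, B, T, U order.

    Hypothetical helper for illustration only. The sort is stable, so
    repeated labels (for example two "S" dimensions) keep their original
    relative order.
    """
    if len(fmt) != len(size):
        raise ValueError("format and size must have the same length")
    perm = sorted(range(len(fmt)), key=lambda i: CANONICAL_ORDER.index(fmt[i]))
    return "".join(fmt[i] for i in perm), tuple(size[i] for i in perm)


# The two example inputs suggested earlier in the thread describe the same layout.
print(canonicalize("CBT", (2, 1, 256)))  # ('CBT', (2, 1, 256))
print(canonicalize("TCB", (256, 2, 1)))  # ('CBT', (2, 1, 256))
```

Under this model, a 256-by-2-by-1 array labelled 'TCB' and a 2-by-1-by-256 array labelled 'CBT' end up with the same format and size, which is why either spelling can be passed as example input.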



Andreas on 23 Jul 2024
Hello everyone,
thank you for your help. Unfortunately, I had to work around this issue, but I could solve it in the end. I believe MATLAB struggles because the models in Ray's RLlib carry a lot of complicated overhead. In particular, the inputs to the network are lists, dicts, etc. that undergo quite a bit of reformatting, which seemed to cause problems. In the end, I extracted the relevant torch modules from the trained RLlib object and combined them in a new torch.nn.Module object. Exporting that object with torch.onnx.export worked out just fine.
Thank you all for your help.
Best, Andreas
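
For readers who want to try the same workaround, here is a minimal sketch of the idea, not the exact code used above. It wraps an encoder, an LSTM, and an output head in one torch.nn.Module whose forward takes only plain tensors, so no dicts or lists reach the exporter. The class name `PlainLSTMPolicy`, the function `export_wrapped`, the layer layout, and the sizes (4 observations, 256 hidden units, 2 actions for CartPole) are all assumptions for illustration. The three submodules are freshly created stand-ins; in practice you would pass in the trained submodules taken from your own policy model.

```python
import torch
from torch import nn


class PlainLSTMPolicy(nn.Module):
    """Hypothetical wrapper exposing an LSTM policy through tensor-only inputs."""

    def __init__(self, encoder: nn.Module, lstm: nn.LSTM, head: nn.Module):
        super().__init__()
        self.encoder = encoder
        self.lstm = lstm
        self.head = head

    def forward(self, obs, h, c):
        # obs: (B, T, obs_dim); h and c: (num_layers, B, hidden)
        features = self.encoder(obs)
        out, (h_next, c_next) = self.lstm(features, (h, c))
        return self.head(out), h_next, c_next


# Assumed sizes for CartPole with a 256-unit LSTM.
OBS_DIM, HIDDEN, NUM_ACTIONS = 4, 256, 2

# Stand-in modules with random weights; replace with the trained ones.
encoder = nn.Sequential(nn.Linear(OBS_DIM, HIDDEN), nn.Tanh())
lstm = nn.LSTM(HIDDEN, HIDDEN, batch_first=True)
head = nn.Linear(HIDDEN, NUM_ACTIONS)
wrapped = PlainLSTMPolicy(encoder, lstm, head).eval()

# Example inputs: batch of 1, sequence length of 1, zero initial state.
B, T = 1, 1
example_inputs = (
    torch.rand(B, T, OBS_DIM),
    torch.zeros(1, B, HIDDEN),
    torch.zeros(1, B, HIDDEN),
)

# Sanity-check the forward pass before exporting.
with torch.no_grad():
    logits, h_next, c_next = wrapped(*example_inputs)
print(logits.shape, h_next.shape, c_next.shape)


def export_wrapped(path: str) -> None:
    """Export the wrapper to ONNX with three named tensor inputs."""
    torch.onnx.export(
        wrapped,
        example_inputs,
        path,
        input_names=["obs", "h_in", "c_in"],
        output_names=["logits", "h_out", "c_out"],
        opset_version=14,
    )
```

Calling `export_wrapped("lstm_policy.onnx")` (the file name is a placeholder) would write a model with three tensor inputs and three tensor outputs. Opset 14 is used only to match the `onnx=14` setting in the question.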

Release

R2024a
