
trainSOLOV2

Train SOLOv2 network to perform instance segmentation

Since R2023b

    Description

    trainedDetector = trainSOLOV2(trainingData,network,options) trains a SOLOv2 network to perform instance segmentation, that is, to detect and segment individual instances of multiple object classes. This syntax supports transfer learning on a pretrained SOLOv2 network, as well as training an uninitialized SOLOv2 network.

    Note

    This functionality requires Deep Learning Toolbox™ and the Computer Vision Toolbox™ Model for SOLOv2 Instance Segmentation. You can install the Computer Vision Toolbox Model for SOLOv2 Instance Segmentation from Add-On Explorer. For more information about installing add-ons, see Get and Manage Add-Ons.

    [trainedDetector,info] = trainSOLOV2(trainingData,network,options) also returns information on the training progress, such as the training loss for each iteration.

    [___] = trainSOLOV2(___,Name=Value) specifies network training options using one or more name-value arguments, in addition to any combination of input and output arguments from previous syntaxes. For example, FreezeSubNetwork="none" specifies not to freeze subnetworks during training.
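    The following sketch puts these syntaxes together in a hypothetical helper, trainOnMyData. It assumes that trainingData is already a datastore in the four-column format described under Input Arguments, that "resnet50-coco" names an installed pretrained SOLOv2 model accepted by the solov2 function, and that classNames lists the object classes in your data. The solver settings are placeholders, not tuned values.

```matlab
function [trainedDetector,info] = trainOnMyData(trainingData,classNames)
% trainOnMyData  Hypothetical helper: transfer learning with trainSOLOV2.
%   trainingData - datastore returning {image, boxes, labels, masks}
%   classNames   - string array of the object class names in trainingData
% Assumption: "resnet50-coco" names an installed pretrained SOLOv2 model.

% Pretrained SOLOv2 network configured for the classes in classNames
network = solov2("resnet50-coco",classNames);

% Placeholder solver settings
options = trainingOptions("sgdm", ...
    InitialLearnRate=0.001, ...
    MaxEpochs=10, ...
    MiniBatchSize=4);

% Request the info output and pass an optional name-value argument
[trainedDetector,info] = trainSOLOV2(trainingData,network,options, ...
    FreezeSubNetwork="none");
end
```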

    Input Arguments


    Labeled ground truth training data, specified as a datastore. You must set up your data so that calling the read and readall functions on the datastore returns a cell array with four columns: Data, Boxes, Labels, and Masks. This list describes the format of the cell in each column.

    • Data: RGB or grayscale image that serves as a network input, specified as an H-by-W-by-3 or H-by-W-by-1 numeric array, respectively.

    • Boxes: Bounding boxes, defined in spatial coordinates as an M-by-4 numeric matrix with rows of the form [x y w h], where:

        ◦ M is the number of object instances in the image, with one axis-aligned rectangle per instance.

        ◦ x and y specify the upper-left corner of the rectangle.

        ◦ w specifies the width of the rectangle, which is its length along the x-axis.

        ◦ h specifies the height of the rectangle, which is its length along the y-axis.

    • Labels: Object class names, specified as an M-by-1 categorical vector, where M is the number of object instances in the image. All categorical data read from the datastore must contain the same categories.

    • Masks: Binary masks, specified as an H-by-W-by-M logical array, where M is the number of object instances in the image. Each channel is the segmentation mask of one object instance.

    You can create a datastore that returns data in the required format using these steps:

    1. Create an ImageDatastore that returns RGB or grayscale image data.

    2. Create a boxLabelDatastore that returns bounding box data and instance labels as a two-element cell array.

    3. Create an ImageDatastore and specify a custom read function that returns mask data as an H-by-W-by-M logical array.

    4. Combine the three datastores using the combine function.
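    A minimal, self-contained sketch of these four steps follows. It writes three synthetic 64-by-64 images, each containing one white square, to a fresh temporary folder, and stores each mask in a MAT-file that a custom read function loads. The folder layout, the MAT-file variable name masks, and the class name "square" are illustrative assumptions; substitute your own data.

```matlab
% Synthetic data: one white square per image (illustrative only).
dataDir = tempname;                      % fresh, unique temporary folder
imgDir  = fullfile(dataDir,"images");
maskDir = fullfile(dataDir,"masks");
mkdir(dataDir); mkdir(imgDir); mkdir(maskDir);

numImages = 3;
H = 64; W = 64;
boxes  = cell(numImages,1);
labels = cell(numImages,1);
for k = 1:numImages
    x = 8*k; y = 10; w = 20; h = 20;     % square position and size
    I = zeros(H,W,3,"uint8");
    I(y:y+h-1,x:x+w-1,:) = 255;          % RGB image with one white square
    masks = false(H,W,1);                % H-by-W-by-M logical array, M = 1
    masks(y:y+h-1,x:x+w-1,1) = true;
    imwrite(I,fullfile(imgDir,sprintf("img%02d.png",k)));
    save(fullfile(maskDir,sprintf("mask%02d.mat",k)),"masks");
    boxes{k}  = [x y w h];               % M-by-4 box matrix
    labels{k} = categorical("square");   % M-by-1 categorical vector
end

% Step 1: images
imds = imageDatastore(imgDir);
% Step 2: boxes and labels, returned as a two-element cell array
blds = boxLabelDatastore(table(boxes,labels));
% Step 3: masks, loaded from MAT-files by a custom read function
maskds = imageDatastore(maskDir,FileExtensions=".mat", ...
    ReadFcn=@(f) getfield(load(f),"masks"));
% Step 4: combine into {image, boxes, labels, masks}
ds = combine(imds,blds,maskds);

data = read(ds);   % first observation as a 1-by-4 cell array
reset(ds);         % rewind before passing ds to trainSOLOV2
```

    Each read returns the image, boxes, labels, and masks for one image, in the column order listed under trainingData.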

    For more information, see Get Started with SOLOv2 for Instance Segmentation.

    SOLOv2 instance segmentation network to train, specified as a solov2 object.

    Training options, specified as a TrainingOptionsSGDM, TrainingOptionsRMSProp, or TrainingOptionsADAM object returned by the trainingOptions (Deep Learning Toolbox) function. To specify the solver name and other options for network training, use the trainingOptions function.
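    As one illustration, the hypothetical helper below builds an options object that uses the Adam solver and a validation datastore; supplying validation data is what lets training report a validation loss. The solver choice and every numeric value are placeholders, and the helper assumes validationData is a datastore in the same four-column format as trainingData.

```matlab
function options = makeSOLOV2Options(validationData)
% makeSOLOV2Options  Hypothetical helper returning options for trainSOLOV2.
%   validationData - datastore in the same four-column format as trainingData
options = trainingOptions("adam", ...
    InitialLearnRate=1e-4, ...          % placeholder value
    MaxEpochs=20, ...                   % placeholder value
    MiniBatchSize=4, ...                % placeholder value
    ValidationData=validationData, ...  % enables validation loss reporting
    ValidationFrequency=50);            % validate every 50 iterations
end
```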

    Name-Value Arguments

    Specify optional pairs of arguments as Name1=Value1,...,NameN=ValueN, where Name is the argument name and Value is the corresponding value. Name-value arguments must appear after other arguments, but the order of the pairs does not matter.

    Example: trainedDetector = trainSOLOV2(trainingData,network,options,FreezeSubNetwork="none") specifies not to freeze subnetworks during training.

    Subnetworks to freeze during training, specified as one of these values:

    • "none" — Do not freeze subnetworks.

    • "backbone" — Freeze the feature extraction (backbone) subnetwork.

    • "backboneAndNeck" — Freeze the feature extraction subnetwork, as well as the path aggregation network used to mix backbone features at different scales.

    The weights of layers in frozen subnetworks do not change during training.
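    A hedged sketch of passing one of these values: the hypothetical wrapper trainWithFreezeLevel checks the value in an arguments block, so that a misspelled value fails before trainSOLOV2 is called. Its default of "none" is an arbitrary choice for the sketch; it is not a statement about the default of FreezeSubNetwork.

```matlab
function [trainedDetector,info] = trainWithFreezeLevel(trainingData,network,options,freezeLevel)
% trainWithFreezeLevel  Hypothetical wrapper around trainSOLOV2 that accepts
% only the three FreezeSubNetwork values listed on this page.
arguments
    trainingData
    network
    options
    freezeLevel (1,1) string {mustBeMember(freezeLevel, ...
        ["none","backbone","backboneAndNeck"])} = "none"
end
[trainedDetector,info] = trainSOLOV2(trainingData,network,options, ...
    FreezeSubNetwork=freezeLevel);
end
```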

    Network training experiment monitoring, specified as an experiments.Monitor (Deep Learning Toolbox) object for use with the Experiment Manager (Deep Learning Toolbox) app. You can use this object to track the progress of training, update information fields in the training results table, record values of the metrics used for training, and produce training plots. For more information on using this app, see the Train Object Detectors in Experiment Manager example.

    The app monitors this information during training:

    • Training loss at each iteration

    • Learning rate at each iteration

    When the options input contains validation data, the app also monitors validation loss at each iteration.

    Output Arguments


    Trained SOLOv2 instance segmentation model, returned as a solov2 object.

    Training progress information, returned as a structure with these fields.

    • TrainingLoss — Training loss at each iteration. The loss is the combination of the classification loss and the mask loss used to train the SOLOv2 network.

    • LearnRate — Learning rate at each iteration.

    Each field contains a numeric vector with one element per training iteration. For information that the function does not calculate at a specific iteration, the value is NaN. If options specifies validation data, the info structure also contains a ValidationLoss field.
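    Because every field holds one element per iteration and uses NaN where a value was not computed, summary statistics need to skip the NaN entries. The hypothetical helper below is one way to do that; it relies only on the field names described above.

```matlab
function summary = summarizeTrainingInfo(info)
% summarizeTrainingInfo  Hypothetical helper that condenses the info output
% of trainSOLOV2. Each field holds one element per iteration; NaN marks
% iterations at which that value was not computed.
lastValid = @(v) v(find(~isnan(v),1,"last"));   % last non-NaN element

summary.NumIterations     = numel(info.TrainingLoss);
summary.FinalTrainingLoss = lastValid(info.TrainingLoss);
summary.FinalLearnRate    = lastValid(info.LearnRate);

% ValidationLoss exists only when the training options include validation data
if isfield(info,"ValidationLoss")
    validIdx = find(~isnan(info.ValidationLoss));
    [summary.BestValidationLoss,k] = min(info.ValidationLoss(validIdx));
    summary.BestValidationIteration = validIdx(k);
end
end
```

    For example, calling summarizeTrainingInfo(info) after training returns the last recorded training loss and learning rate and, when validation ran, the iteration with the lowest validation loss.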

    Tips

    • For cases in which the network training does not converge, try specifying the GradientThreshold argument when calling the trainingOptions function.

    • To perform transfer learning on a data set whose content is similar to the COCO data set, freeze the backbone and neck of the network by specifying FreezeSubNetwork="backboneAndNeck". Freezing these subnetworks can help the network training converge faster.
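    The hypothetical helper below applies both tips in one call. The GradientThreshold value of 1 and the other solver settings are placeholders to tune, and the helper assumes that network is a COCO-pretrained solov2 object already configured for your classes.

```matlab
function [trainedDetector,info] = transferLearnCOCOLike(trainingData,network)
% transferLearnCOCOLike  Hypothetical helper that applies both tips.
%   trainingData - datastore in the four-column format
%   network      - solov2 object, assumed to be pretrained on COCO
options = trainingOptions("sgdm", ...
    InitialLearnRate=0.001, ...   % placeholder value
    MaxEpochs=10, ...             % placeholder value
    GradientThreshold=1);         % clip large gradients (placeholder threshold)

% Freeze the backbone and neck because the data resembles COCO
[trainedDetector,info] = trainSOLOV2(trainingData,network,options, ...
    FreezeSubNetwork="backboneAndNeck");
end
```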

    Version History

    Introduced in R2023b