trainDLCHOMP

Train deep-learning-based CHOMP optimizer

Since R2024a

    Description

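    info = trainDLCHOMP(dlchomp,trainingData,lossFcn,options) trains the neural network of the deep-learning-based CHOMP optimizer dlchomp by using the training data trainingData, the loss function lossFcn, and the training options options. The function returns information about the training in info.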

    The trainDLCHOMP function requires the Deep Learning Toolbox™.

    Input Arguments

    dlchomp

    Deep-learning-based CHOMP optimizer, specified as a dlCHOMP object.

    trainingData

    Training data, specified as a dlCHOMPDatastore object.

    Use the generateSamples function to generate the training datastore.

    lossFcn

    Loss function to use for training, specified as one of these values:

    • "crossentropy" — Cross-entropy loss for classification tasks.

    • "binary-crossentropy" — Binary cross-entropy loss for binary and multilabel classification tasks.

    • "mse" / "mean-squared-error" / "l2loss" — Mean squared error for regression tasks.

    • "mae" / "mean-absolute-error" / "l1loss" — Mean absolute error for regression tasks.

    • "huber" — Huber loss for regression tasks.

    • Function handle with the syntax loss = f(Y1,...,Yn,T1,...,Tm), where Y1,...,Yn are dlarray objects that correspond to the n network predictions and T1,...,Tm are dlarray objects that correspond to the m targets.
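    For reference, the built-in loss names correspond to the standard textbook definitions below, written for N observations with predictions y and targets t, K classes for cross-entropy, and a Huber transition point δ. Treat this as a summary sketch rather than the exact toolbox implementation: the Deep Learning Toolbox functions may scale or normalize differently, for example by including a factor of 1/2 in the mean squared error.

```latex
\begin{aligned}
\text{MSE} &= \frac{1}{N}\sum_{i=1}^{N}\left(y_i - t_i\right)^2 \\
\text{MAE} &= \frac{1}{N}\sum_{i=1}^{N}\left\lvert y_i - t_i\right\rvert \\
\text{Huber}_\delta &= \frac{1}{N}\sum_{i=1}^{N}\ell_\delta\!\left(y_i - t_i\right),
\qquad
\ell_\delta(r) =
\begin{cases}
\tfrac{1}{2}r^2, & \lvert r\rvert \le \delta \\
\delta\left(\lvert r\rvert - \tfrac{1}{2}\delta\right), & \lvert r\rvert > \delta
\end{cases} \\
\text{CE} &= -\frac{1}{N}\sum_{n=1}^{N}\sum_{k=1}^{K} t_{nk}\ln y_{nk} \\
\text{BCE} &= -\frac{1}{N}\sum_{i=1}^{N}\left[t_i\ln y_i + \left(1 - t_i\right)\ln\left(1 - y_i\right)\right]
\end{aligned}
```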

    Tip

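    For weighted cross-entropy, specify the function handle @(Y,T) crossentropy(Y,T,weights), where weights is a workspace variable that holds the weights. MATLAB captures the value of weights when you create the anonymous function.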
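    The tip works because lossFcn receives only predictions and targets, so an extra parameter such as the weights must travel inside the function handle. The Python sketch below illustrates the same closure pattern outside MATLAB. The crossentropy and make_weighted_loss helpers are hypothetical stand-ins written for illustration. They assume class probabilities, one-hot targets, per-class weights, and averaging over observations, and they are not the Deep Learning Toolbox implementation.

```python
import math


def crossentropy(Y, T, weights=None):
    """Hypothetical weighted cross-entropy (illustration only).

    Y: per-observation class probabilities, e.g. [[0.8, 0.2], ...]
    T: one-hot targets with the same shape as Y
    weights: optional per-class weights; defaults to 1 for every class
    Returns the weighted loss averaged over the observations.
    """
    num_classes = len(Y[0])
    w = weights if weights is not None else [1.0] * num_classes
    total = 0.0
    for y_row, t_row in zip(Y, T):
        # Only target classes (t > 0) contribute, which also avoids log(0).
        total -= sum(wk * tk * math.log(yk)
                     for wk, tk, yk in zip(w, t_row, y_row) if tk > 0)
    return total / len(Y)


def make_weighted_loss(weights):
    """Return a two-argument loss f(Y, T) with the class weights fixed."""
    fixed = list(weights)  # copy, so later edits to `weights` have no effect
    return lambda Y, T: crossentropy(Y, T, fixed)


weights = [1.0, 3.0]
loss_fcn = make_weighted_loss(weights)  # plays the role of lossFcn

Y = [[0.8, 0.2], [0.4, 0.6]]
T = [[1, 0], [0, 1]]
print(loss_fcn(Y, T))
```

    The caller only ever invokes loss_fcn(Y, T). Copying the weights at creation time mimics the way a MATLAB anonymous function captures variable values when you define it.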

    options

    Training options, specified as an object returned by the trainingOptions (Deep Learning Toolbox) function.

    Output Arguments

    info

    Training information, returned as a TrainingInfo object with these properties:

    • TrainingHistory — Information about training iterations

    • ValidationHistory — Information about validation iterations

    • OutputNetworkIteration — Iteration that corresponds to the trained network

    • StopReason — Reason why training stopped

    To open or close the training progress plot, pass info to the show or close function.

    Version History

    Introduced in R2024a