Training a multitask network using custom training loops with dlnetwork()

Julius Å on 26 Aug 2020
Commented: Sven on 3 May 2023
I want to train a multitask network, i.e. a neural network that has one input and produces two outputs from separate output layers, connected as shown in the image below. Each output should have its own loss function: OUTPUT1 is the input to LOSS1 and OUTPUT2 is the input to LOSS2. The total loss used for backpropagation should be TOTAL_LOSS = LOSS1 + LOSS2.
As the image below shows, the final layers carry warnings. This is because I did not specify loss functions when creating these layers, so they are not output layers. I left them out because the only example I have found of training a multiple-output network with a combined loss function (https://se.mathworks.com/help/deeplearning/ug/train-network-with-multiple-outputs.html) does not create the network with the dlnetwork() function, but instead uses a model function in which all operations are carried out explicitly. The problem is that my network is a modified U-Net, and writing a model function for it would mean implementing every layer function, such as dropout, from scratch. This would be very tedious.
Is there any way to train a network with multiple outputs and a combined loss function, as described above, using the dlnetwork() function without having to write a model function from scratch?
Note: A few months ago I posted a similar but less direct question about how to train a multitask network in MATLAB, which received no answer:
I then switched to Keras and Python for this problem, with great results. Now I want to know whether similar results can be obtained in MATLAB.
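The setup described in the question can be sketched with dlnetwork as follows. This is a minimal illustration under stated assumptions, not the asker's U-Net: it assumes R2020b or later (for featureInputLayer), uses a small fully connected trunk, and the layer names (`in`, `fc_shared`, `drop`, `fc1`, `out1`, `out2`) and sizes are hypothetical. It shows that a layer graph with two unconnected heads and no output layers can be converted with dlnetwork, and that `forward` with the `'Outputs'` option returns both head outputs with dropout applied (training behavior), while `predict` returns them with dropout disabled, so no layer has to be reimplemented by hand.

```matlab
% Toy stand-in for the modified U-Net (hypothetical sizes and layer names).
% The trunk ends in a dropout layer that feeds two heads; neither head ends
% in an output layer, because dlnetwork does not accept output layers.
layers = [
    featureInputLayer(10, 'Name', 'in')
    fullyConnectedLayer(32, 'Name', 'fc_shared')
    reluLayer('Name', 'relu')
    dropoutLayer(0.5, 'Name', 'drop')
    fullyConnectedLayer(3, 'Name', 'fc1')
    softmaxLayer('Name', 'out1')];      % head 1: 3-class probabilities
lgraph = layerGraph(layers);

% Head 2: a single regression value branching off the shared dropout layer.
lgraph = addLayers(lgraph, fullyConnectedLayer(1, 'Name', 'out2'));
lgraph = connectLayers(lgraph, 'drop', 'out2');

net = dlnetwork(lgraph);
disp(net.OutputNames)                   % both heads are network outputs

% A random batch of 4 observations, formatted as channel-by-batch.
X = dlarray(rand(10, 4, 'single'), 'CB');

% forward: training behavior (dropout active); predict: inference (dropout off).
[Y1, Y2] = forward(net, X, 'Outputs', ["out1", "out2"]);
[P1, P2] = predict(net, X, 'Outputs', ["out1", "out2"]);
size(Y1)   % 3-by-4
size(Y2)   % 1-by-4
```

The warnings shown in the question's image are consistent with this approach: a network intended for dlnetwork is expected to end in ordinary layers, and the losses are computed outside the network.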
Sven on 3 May 2023
Julius, did you ever succeed in what you were trying to do? We are in the same boat: we can create a network with two separate outputs in Python and would love to do the same in MATLAB, but we are hitting the same brick wall you described.


Answers (1)

Nomit Jangid on 30 Nov 2020
Hi,
The documentation page you refer to uses the dlnetwork function. The article may have been updated since you last visited it.
Have a look at the following article.
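A custom training loop in the spirit of this answer, where TOTAL_LOSS = LOSS1 + LOSS2, might look like the sketch below. Assumptions: R2020b or later, a toy two-head dlnetwork with hypothetical layer names `out1` and `out2`, random toy data, a classification head trained with `crossentropy` and a regression head trained with `mse`, and Adam via `adamupdate`. The helper `modelLoss` is a hypothetical name; `dlfeval` and `dlgradient` do the automatic differentiation, so the network's layers are used as-is.

```matlab
% Toy two-head network (hypothetical sizes and layer names).
layers = [
    featureInputLayer(10, 'Name', 'in')
    fullyConnectedLayer(32, 'Name', 'fc_shared')
    reluLayer('Name', 'relu')
    fullyConnectedLayer(3, 'Name', 'fc1')
    softmaxLayer('Name', 'out1')];
lgraph = layerGraph(layers);
lgraph = addLayers(lgraph, fullyConnectedLayer(1, 'Name', 'out2'));
lgraph = connectLayers(lgraph, 'relu', 'out2');
net = dlnetwork(lgraph);

% Random toy batch: one-hot class targets (T1) and regression targets (T2).
numObs = 16;
X = dlarray(rand(10, numObs, 'single'), 'CB');
labels = randi(3, 1, numObs);
T1 = zeros(3, numObs, 'single');
T1(sub2ind(size(T1), labels, 1:numObs)) = 1;
T1 = dlarray(T1, 'CB');
T2 = dlarray(rand(1, numObs, 'single'), 'CB');

% Custom training loop: one combined loss and one Adam update per iteration.
avgGrad = [];
avgSqGrad = [];
for iteration = 1:50
    % dlfeval enables automatic differentiation inside modelLoss.
    [loss, gradients] = dlfeval(@modelLoss, net, X, T1, T2);
    [net, avgGrad, avgSqGrad] = adamupdate(net, gradients, ...
        avgGrad, avgSqGrad, iteration);
end
fprintf('Total loss after %d iterations: %.4f\n', iteration, extractdata(loss));

function [loss, gradients] = modelLoss(net, X, T1, T2)
    % One forward pass returns both heads, so both losses see the same activations.
    [Y1, Y2] = forward(net, X, 'Outputs', ["out1", "out2"]);
    loss1 = crossentropy(Y1, T1);   % LOSS1: classification head
    loss2 = mse(Y2, T2);            % LOSS2: regression head
    loss = loss1 + loss2;           % TOTAL_LOSS = LOSS1 + LOSS2
    % Gradients of the summed loss with respect to all learnable parameters.
    gradients = dlgradient(loss, net.Learnables);
end
```

If the real network contains batch normalization layers, `forward` can also return the updated network state (`[Y1, Y2, state] = forward(...)`), which would need to be passed out of the loss function and assigned to `net.State` after each iteration.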
