
How to implement PyTorch's Linear layer in Matlab?

John Smith on 11 Feb 2023
Edited: Matt J on 15 Feb 2023
Hello,
How can I implement PyTorch's Linear layer in Matlab?
The problem is that Linear does not flatten its inputs whereas Matlab's fullyConnectedLayer does, so the two are not equivalent.
Thx,
J
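For reference, torch.nn.Linear(in_features, out_features) computes y = x*W' + b over the last dimension of its input only, so an input of size (*, in_features) produces an output of size (*, out_features) with every leading dimension left intact. Below is a plain-MATLAB sketch of those semantics (no Deep Learning Toolbox). The sizes and variable names are illustrative assumptions, and the feature dimension is placed last to mirror PyTorch's (*, in_features) convention.

```matlab
% Sketch of torch.nn.Linear semantics in plain MATLAB (illustrative sizes).
% y = x*W.' + b is applied along the LAST dimension only; all leading
% dimensions act as batch dimensions and are never mixed together.
inFeatures  = 3;
outFeatures = 4;
W = rand(outFeatures, inFeatures);     % same layout as Linear.weight (out x in)
b = rand(1, outFeatures);              % same layout as Linear.bias

X  = rand(5, 2, inFeatures);           % leading dims 5x2, last dim = in_features
sz = size(X);
X2 = reshape(X, [], inFeatures);       % one row per leading-index combination
Y2 = X2*W.' + b;                       % implicit expansion adds b to every row
Y  = reshape(Y2, [sz(1:end-1), outFeatures]);   % back to 5x2x4
```

Because MATLAB is column-major, reshape(X, [], inFeatures) keeps each X(i,j,:) intact as one row, so no features from different positions are combined.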

Answers (4)

Matt J on 11 Feb 2023
Edited: Matt J on 11 Feb 2023
One possibility might be to express the linear layer as a cascade of a fullyConnectedLayer followed by a functionLayer. The functionLayer can reshape the flattened output of the fully connected layer back to the form you want:
layer = functionLayer(@(X)reshape(X,[h,w,c])); % h, w, c: the desired output size
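Why the plain cascade is more general than Linear: fullyConnectedLayer multiplies the flattened input by one dense weight matrix, whereas a per-column, Linear-style map corresponds only to the block-diagonal weight kron(eye(k), A), through the identity vec(A*X) = kron(eye(k), A)*vec(X). The base-MATLAB sketch below (illustrative sizes, bias omitted) checks that identity. With these sizes an unconstrained fully connected layer has 20*15 = 300 free weights instead of the 12 in A, so it reproduces Linear only if its weight happens to have (or is constrained to) this structure.

```matlab
% Sketch: a flattening fully connected layer reproduces a per-column,
% Linear-style map only when its weight is block diagonal.
% Identity used: vec(A*X) = kron(eye(k), A) * vec(X)   (column-major vec)
A = rand(4, 3);              % Linear-style weight applied to every column of X
X = rand(3, 5);              % k = 5 columns, each a 3-element feature vector
Wfc = kron(eye(5), A);       % 20x15 weight a flattening FC layer would need
yFlat = Wfc*X(:);            % what fullyConnectedLayer computes (bias omitted)
Y = reshape(yFlat, 4, 5);    % what a reshaping functionLayer would restore
```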
  9 Comments
Matt J on 13 Feb 2023
Edited: Matt J on 13 Feb 2023
"This solution sums all channels together."
No, it won't. (Keep in mind that this is the third solution I've proposed as more information about your aims has come out.) After the reshaping, each channel is contained in its own column of X, and because the filter you apply to X is (H*W)x1xN, there is no way for the filter to combine elements from different columns.
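A base-MATLAB sketch of that argument, with filter2 (2-D correlation, 'valid' shape) standing in for the convolution layer; the sizes are illustrative assumptions. Each filter is exactly as tall as a column and one element wide, so the 'valid' output has a single row and every output column is a dot product with one input column. Changing channel 2 therefore leaves the channel-1 outputs untouched:

```matlab
% Sketch: an (H*W)x1 filter applied without padding to an (H*W)xC matrix
% can only slide along the columns, so each output column depends on a
% single input column (one channel).
H = 3; W = 3; C = 2; N = 4;
F  = rand(H*W, 1, N);                   % N filters, each (H*W)x1
X1 = rand(H, W, C);
X2 = X1;
X2(:,:,2) = rand(H, W);                 % change channel 2 only

Y1 = zeros(N, C);
Y2 = zeros(N, C);
for n = 1:N
    % filter2 computes 2-D correlation (no kernel flip); 'valid' => 1xC output
    Y1(n,:) = filter2(F(:,:,n), reshape(X1, H*W, C), 'valid');
    Y2(n,:) = filter2(F(:,:,n), reshape(X2, H*W, C), 'valid');
end
% Column 1 (channel 1) is the same in Y1 and Y2; only column 2 changes.
```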
John Smith on 13 Feb 2023
Got it. This looks nice.



John Smith on 13 Feb 2023
Edited: John Smith on 13 Feb 2023
It is possible to perform matrix multiplication using convolution as described in "Fast algorithms for matrix multiplication using pseudo-number-theoretic transforms" (behind paywall):
  1. Converting the matrix A to a sequence
  2. Converting the matrix B to a sparse sequence
  3. Performing a 1-D convolution between the two sequences to obtain a result sequence
  4. Extracting the entries of matrix C from the result sequence
Unfortunately, the paper provides equations only for the square-matrix case, so I worked out the general case. The critical equations are:
For the A-sequence: …, …, polynomial degree …
For the B-sequence: …, …, polynomial degree …
For the C-sequence: …, …, polynomial degree …
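For the general case, one indexing scheme that works (our own derivation, not necessarily the one used in the paper or by the poster) takes A of size m×n, B of size n×p, a block spacing K = 2n-1, and reads C off the product polynomial, where [x^e] c(x) denotes the coefficient of x^e. Because |j - j'| <= n-1, every cross term A_ij B_j'k lands at offset (n-1) + (j - j'), which lies in [0, K-1] inside block (i-1) + (k-1)m. Blocks therefore never overlap, and offset n-1 is reached only when j = j':

```latex
\begin{aligned}
a(x) &= \sum_{i=1}^{m}\sum_{j=1}^{n} A_{ij}\,x^{(i-1)K+(j-1)}, & \deg a &= (m-1)K+n-1,\\
b(x) &= \sum_{j=1}^{n}\sum_{k=1}^{p} B_{jk}\,x^{(n-j)+(k-1)mK}, & \deg b &= (p-1)mK+n-1,\\
c(x) &= a(x)\,b(x) = \sum_{i,j,j',k} A_{ij}B_{j'k}\,x^{[(i-1)+(k-1)m]K+(n-1)+(j-j')}, & \deg c &= \deg a+\deg b,\\
C_{ik} &= [x^{(i-1)K+(k-1)mK+(n-1)}]\,c(x), & K &= 2n-1.
\end{aligned}
```

Only the b-sequence is sparse: its nonzeros come in runs of n separated by gaps of mK - n zeros.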
Also, pay attention to the fact that MATLAB expects polynomial coefficients in descending order of powers.
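A base-MATLAB sketch of the whole procedure with one concrete indexing (the spacing K = 2n-1 and the placement of A and B are our own assumptions, not taken from the paper). The sequences are stored in ascending order, so element t+1 holds the coefficient of x^t. conv returns the product in that same order, so flip the vectors before passing them to functions such as polyval that expect descending order:

```matlab
% Sketch: C = A*B using a single 1-D convolution (assumed indexing scheme).
m = 4; n = 3; p = 2;
A = rand(m, n);
B = rand(n, p);
K = 2*n - 1;                                  % block spacing, prevents overlap

a = zeros((m-1)*K + n, 1);                    % dense sequence built from A
for i = 1:m
    a((i-1)*K + (1:n)) = A(i, :);             % coeff of x^((i-1)K+(j-1)) = A(i,j)
end

b = zeros((p-1)*m*K + n, 1);                  % sparse sequence built from B
for k = 1:p
    b((k-1)*m*K + (1:n)) = flipud(B(:, k));   % coeff of x^((n-j)+(k-1)mK) = B(j,k)
end

c = conv(a, b);                               % one full 1-D convolution

C = zeros(m, p);
for i = 1:m
    for k = 1:p
        C(i, k) = c((i-1)*K + (k-1)*m*K + (n-1) + 1);   % +1 for 1-based indexing
    end
end
```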
This is how to call dlconv so that it matches conv:
a=(1:4)';
b=(5:10)';
dla=dlarray(a,'SCB');
weights=flipud(b); % dlconv computes cross-correlation, so flip b to match conv
% filterSize-by-numChannels-by-numFilters array, where
% filterSize is the size of the 1-D filters,
% numChannels is the number of channels of the input data,
% and numFilters is the number of filters.
bias = 0;
dlc=dlconv(dla,weights,bias,'Padding',length(weights)-1);
c = extractdata(dlc);
assert(all(abs(c-conv(a,b)) < 1e-14),'conv is different from dlconv');
Hope this helps.
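Why the flip and the padding of length(weights)-1 are needed: a deep learning convolution slides the filter over the input and takes dot products (cross-correlation), while conv flips one of the sequences, and conv's full-length output requires length(b)-1 zeros on each side of the input. A base-MATLAB sketch of that equivalence with the same a and b, where the loop stands in for dlconv:

```matlab
% Sketch: full convolution == cross-correlation with the flipped kernel
% after zero-padding the input by (length(kernel)-1) on both ends.
a = (1:4)';
b = (5:10)';
M = numel(b);
w = flipud(b);                             % correlation kernel = flipped conv kernel
ap = [zeros(M-1, 1); a; zeros(M-1, 1)];    % same as 'Padding', M-1
y = zeros(numel(a) + M - 1, 1);
for t = 1:numel(y)
    y(t) = w.' * ap(t:t+M-1);              % dot product at each shift (correlation)
end
```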

Matt J on 13 Feb 2023
Edited: Matt J on 13 Feb 2023
Another possible way to interpret your question is that you are trying to apply pagemtimes to the input X with a non-learnable matrix A, where the different channels of X are the pages. That can also be done with a functionLayer, as illustrated below with both normal arrays and dlarrays:
A=rand(4,3); %non-learnable matrix A
xdata=rand(3,3,2); %input layer data with 2 channels
multLayer=functionLayer(@(X) dlarray( pagemtimes(A,stripdims(X)) ,dims(X)) );
X=dlarray(xdata,'SSC');
Y=multLayer.predict(X)
Y = 
  4(S) × 3(S) × 2(C) dlarray

(:,:,1) =

    0.8480    0.9729    0.9338
    1.1000    1.6463    1.5592
    0.9130    1.2452    1.1881
    1.1243    1.3362    1.2971

(:,:,2) =

    0.8228    0.5187    1.1387
    1.1783    0.7549    1.5675
    0.9390    0.5816    1.2925
    1.0862    0.6101    1.6128
%%Verify agreement with normal pagemtimes
ydata=pagemtimes(A,xdata)
ydata = 

ydata(:,:,1) =

    0.8480    0.9729    0.9338
    1.1000    1.6463    1.5592
    0.9130    1.2452    1.1881
    1.1243    1.3362    1.2971

ydata(:,:,2) =

    0.8228    0.5187    1.1387
    1.1783    0.7549    1.5675
    0.9390    0.5816    1.2925
    1.0862    0.6101    1.6128
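Since a batch dimension comes up later in the thread: pagemtimes applies implicit expansion to the page dimensions (3 and beyond), so a single 2-D A is reused for every channel and every batch observation. A base-MATLAB sketch (pagemtimes requires R2020b or later; the 4-D h×w×C×N layout is an assumption) that checks this against an explicit loop:

```matlab
% Sketch: one 2-D matrix A applied to every channel AND every batch page.
A = rand(4, 3);
X = rand(3, 3, 2, 5);            % h x w x channels x batch (assumed layout)
Y = pagemtimes(A, X);            % A is expanded across pages 3 and 4

Yloop = zeros(4, 3, 2, 5);       % explicit reference computation
for c = 1:2
    for n = 1:5
        Yloop(:,:,c,n) = A*X(:,:,c,n);
    end
end
```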
  3 Comments
Matt J on 13 Feb 2023
Edited: Matt J on 13 Feb 2023
Here is the modification for the case where A is learnable. I use a pre-declared A only so that I can demonstrate and test the response; in a real scenario, you would not supply weights to the convolution2dLayer.
X=dlarray(rand(3,3,2),'SSC'); A=rand(4,3);
[h,w,c]=size(X);
L1=functionLayer( @(z) z(:,:) );
Lconv=convolution2dLayer([h,1],4,'Weights',permute(A,[2,3,4,1]));
L2=functionLayer(@(z)recoverShape(z,w,c) ,'Formattable',1);
net=dlnetwork([L1,Lconv,L2],X);
Yfinal=net.predict(X)
Yfinal = 
  4(S) × 3(S) × 2(C) single dlarray

(:,:,1) =

    1.1104    0.9300    0.6060
    1.2600    0.9268    0.8284
    0.9742    0.8960    0.7068
    1.5047    1.0413    0.8057

(:,:,2) =

    0.4512    0.3565    0.2455
    0.6190    0.5052    0.5721
    0.2262    0.4928    0.3368
    0.8938    0.4187    0.5370
And as before, we can compare to the result of a plain-vanilla pagemtimes operation and see that it gives the same result:
Ycheck=pagemtimes(A, extractdata(X))
Ycheck = 

Ycheck(:,:,1) =

    1.1104    0.9300    0.6060
    1.2600    0.9268    0.8284
    0.9742    0.8960    0.7068
    1.5047    1.0413    0.8057

Ycheck(:,:,2) =

    0.4512    0.3565    0.2455
    0.6190    0.5052    0.5721
    0.2262    0.4928    0.3368
    0.8938    0.4187    0.5370
function out=recoverShape(z,w,c)
z=permute( stripdims(z), [3,2,1]);
out=dlarray(reshape(z,[],w,c),'SSC');
end
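For readers without the Deep Learning Toolbox, here is a base-MATLAB mirror of the L1 -> Lconv -> L2 pipeline. It is a sketch: filter2 stands in for the convolution layer and is assumed to match its no-flip behavior, which is consistent with the matching outputs above. It shows why the weights are permute(A,[2,3,4,1]): filter f is the column A(f,:)', and an [h,1] filter with no padding turns each column of X(:,:) into A(f,:) times that column.

```matlab
% Sketch: base-MATLAB mirror of L1 -> Lconv -> L2 (illustrative sizes).
h = 3; w = 3; c = 2; numFilters = 4;
A = rand(numFilters, h);
X = rand(h, w, c);

Z1 = X(:,:);                                   % L1: h x (w*c), channels side by side
Wt = permute(A, [2,3,4,1]);                    % h x 1 x 1 x numFilters, as in Lconv
Z2 = zeros(1, w*c, numFilters);                % output of an [h,1] 'valid' conv
for f = 1:numFilters
    Z2(:,:,f) = filter2(Wt(:,:,1,f), Z1, 'valid');   % correlation with A(f,:)'
end
Y = reshape(permute(Z2, [3,2,1]), [], w, c);   % L2: same steps as recoverShape
```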
John Smith on 14 Feb 2023
Edited: John Smith on 14 Feb 2023
Very nice! It just needs the batch dimension added.
I'd suggest putting this in a separate answer so that I can accept it.
P.S. Too bad this isn't available as a built-in in MATLAB.



Matt J on 14 Feb 2023
Edited: Matt J on 14 Feb 2023
Another approach is to write your own custom layer for channel-wise matrix multiplication. I have attached a possible version of it:
X=rand(3,3,2);
L=pagemtimesLayer(4); %Custom layer - premultiplies channels by 4-row learnable matrix A
L=initialize(L, X);
Ypred=L.predict(X)
Ypred = 
  4(S) × 3(S) × 2(C) dlarray

(:,:,1) =

    0.6102    0.3216    0.8590
    0.8080    0.5732    1.3988
    0.2763    0.1120    0.2556
    0.5463    0.8053    1.2450

(:,:,2) =

    0.6860    0.9692    0.6784
    1.1580    1.5767    1.1105
    0.1999    0.2773    0.1199
    1.1686    1.3306    0.5205
Ycheck=pagemtimes(L.A,X) %Check agreement with a direct call to pagemtimes()
Ycheck = 

Ycheck(:,:,1) =

    0.6102    0.3216    0.8590
    0.8080    0.5732    1.3988
    0.2763    0.1120    0.2556
    0.5463    0.8053    1.2450

Ycheck(:,:,2) =

    0.6860    0.9692    0.6784
    1.1580    1.5767    1.1105
    0.1999    0.2773    0.1199
    1.1686    1.3306    0.5205
  8 Comments
John Smith on 15 Feb 2023
Edited: John Smith on 15 Feb 2023
Presumably X (and Z) have a batch dimension, so when computing the bias gradient, dLdZ should at least be summed over that dimension. Whether to sum over the other dimensions depends on the shape of the bias: if it is a scalar, all other dimensions should be summed over as well; if it does not vary along the channel dimension, the channel dimension should also be summed over; and so on.
Matt J on 15 Feb 2023
That sounds right.
Although part of me questions whether it was the best design choice by TMW (The MathWorks) to make the user responsible for summing over batched inputs in the backward() method, since that dimension should always be handled the same way.
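A sketch of what those summation rules look like for a layer computing Z = pagemtimes(A,X) + b. The h×w×C×N layout and a bias b of size r×1, shared across columns, channels and batch, are assumptions, and this is not the code of the attached layer. dLdA is summed over every page (channel and batch), and dLdb over every dimension along which b is broadcast:

```matlab
% Sketch: backward pass for Z = pagemtimes(A,X) + b (assumed shapes).
r = 4; h = 3; w = 3; C = 2; N = 5;
A = rand(r, h);                        % learnable matrix
b = rand(r, 1);                        % hypothetical bias: one value per output row
X = rand(h, w, C, N);                  % input with channel and batch dimensions

Z = pagemtimes(A, X) + b;              % forward; b expands over dims 2, 3 and 4

dLdZ = rand(size(Z));                  % stand-in for the upstream gradient

dLdX = pagemtimes(A, 'transpose', dLdZ, 'none');              % A.'*dLdZ per page
dLdA = sum(pagemtimes(dLdZ, 'none', X, 'transpose'), [3 4]);  % sum over C and N
dLdb = sum(dLdZ, [2 3 4]);             % sum over every dim that b is shared along
```

A quick consistency check: because Z is linear in (A, b) and in (X, b), sum(dLdZ.*Z,'all') must equal both sum(dLdA.*A,'all') + sum(dLdb.*b,'all') and sum(dLdX.*X,'all') + sum(dLdb.*b,'all').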


Categories

Find more on Custom Training Loops in Help Center and File Exchange

Release

R2022b
