이 질문을 팔로우합니다.
- 팔로우하는 게시물 피드에서 업데이트를 확인할 수 있습니다.
- 정보 수신 기본 설정에 따라 이메일을 받을 수 있습니다.
I'm using VIT transformer in my code. How to convert the output of 1D layer of VIT into 2D with format SSCB?
댓글 수: 8
Hi Abdulrahman,
I cannot execute the code because visionTransformer requires Computer Vision Toolbox. To illustrate resolving your error, I had to adapt your given code from mathworks for input dimensions of 24 x 24 x 768, by adjusting the reshaping and processing steps accordingly. Here is update the code step by step:
% Get Vision Transformer model
net = visionTransformer;
% Create dummy input
input = dlarray(rand(24,24,768),'SSCB');
% Obtain output embedding from the last LayerNormalizationLayer
out = forward(net, input, 'Outputs', 'encoder_norm');
% Reshape output patch embedding
out = reshapePatchEmbedding(out);
function out = reshapePatchEmbedding(in)
% Remove output embedding corresponding to the class token from the input
out = in(2:end,:,:);
% Reshape the resulting embedding to the input format
WH = sqrt(size(out, 1));
C = size(out, 2);
out = reshape(out, WH, WH, C, []); % Shape is W x H x C x N
out = permute(out, [2, 1, 3, 4]); % Shape is H x W x C x N
% Convert to formatted dlarray
out = dlarray(out, 'SSCB');
end
So, in my updated code snippet, I changed the dummy input dimensions to 24 x 24 x 768 to match the specified input size. The reshaping function reshapePatchEmbedding has been adjusted to handle the new dimensions correctly. Please let me know if this helps resolve your issue.
답변 (2개)
댓글 수: 1
Hi Abdulrahman,
I cannot execute the code because visionTransformer requires Computer Vision Toolbox. To illustrate resolving your error, I had to adapt your given code from mathworks for input dimensions of 24 x 24 x 768, by adjusting the reshaping and processing steps accordingly. Here is update the code step by step:
% Get Vision Transformer model
net = visionTransformer;
% Create dummy input
input = dlarray(rand(24,24,768),'SSCB');
% Obtain output embedding from the last LayerNormalizationLayer
out = forward(net, input, 'Outputs', 'encoder_norm');
% Reshape output patch embedding
out = reshapePatchEmbedding(out);
function out = reshapePatchEmbedding(in)
% Remove output embedding corresponding to the class token from the input
out = in(2:end,:,:);
% Reshape the resulting embedding to the input format
WH = sqrt(size(out, 1));
C = size(out, 2);
out = reshape(out, WH, WH, C, []); % Shape is W x H x C x N
out = permute(out, [2, 1, 3, 4]); % Shape is H x W x C x N
% Convert to formatted dlarray
out = dlarray(out, 'SSCB');
end
So, in my updated code snippet, I changed the dummy input dimensions to 24 x 24 x 768 to match the specified input size. The reshaping function reshapePatchEmbedding has been adjusted to handle the new dimensions correctly. Please let me know if this helps resolve your issue.
참고 항목
카테고리
Community Treasure Hunt
Find the treasures in MATLAB Central and discover how the community can help you!
Start Hunting!오류 발생
페이지가 변경되었기 때문에 동작을 완료할 수 없습니다. 업데이트된 상태를 보려면 페이지를 다시 불러오십시오.
웹사이트 선택
번역된 콘텐츠를 보고 지역별 이벤트와 혜택을 살펴보려면 웹사이트를 선택하십시오. 현재 계신 지역에 따라 다음 웹사이트를 권장합니다:
또한 다음 목록에서 웹사이트를 선택하실 수도 있습니다.
사이트 성능 최적화 방법
최고의 사이트 성능을 위해 중국 사이트(중국어 또는 영어)를 선택하십시오. 현재 계신 지역에서는 다른 국가의 MathWorks 사이트 방문이 최적화되지 않았습니다.
미주
- América Latina (Español)
- Canada (English)
- United States (English)
유럽
- Belgium (English)
- Denmark (English)
- Deutschland (Deutsch)
- España (Español)
- Finland (English)
- France (Français)
- Ireland (English)
- Italia (Italiano)
- Luxembourg (English)
- Netherlands (English)
- Norway (English)
- Österreich (Deutsch)
- Portugal (English)
- Sweden (English)
- Switzerland
- United Kingdom(English)
아시아 태평양
- Australia (English)
- India (English)
- New Zealand (English)
- 中国
- 日本Japanese (日本語)
- 한국Korean (한국어)