How is the number of parameters calculated when a multi-head self-attention layer is used in a CNN model?

72 views (last 30 days)
I have run the example in the following link in two cases:
Case 1: NumHeads = 4, NumKeyChannels = 784
Case 2: NumHeads = 8, NumKeyChannels = 392
Note that 4 × 784 = 8 × 392 = 3136, the size of the input feature vector to the attention layer. I calculated the number of model parameters in the two cases and got 9.8 M for the first case and 4.9 M for the second.
I expected the number of learnable parameters to be the same. However, MATLAB reports different parameter counts.
My understanding from research papers is that the total parameter count should not depend on how the input is split across heads: as long as the input feature vector has the same size, and the number of heads times the per-head channel count equals that size, the number of parameters should be the same.
Why does MATLAB’s selfAttentionLayer produce different parameter counts for these two configurations? Am I misinterpreting how the layer is implemented in this toolbox?
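
For reference, the arithmetic behind that expectation can be written out in a few lines. This is a minimal sketch assuming the standard transformer parameterization (four projection matrices Q, K, V, and output; biases ignored); the "total key channels" variant at the end is a hypothetical reading of NumKeyChannels, not taken from the selfAttentionLayer source, though it does reproduce the reported counts.

d_in = 3136;                    % input feature dimension
% Standard convention: per-head size times heads equals d_in, so the
% four d_in-by-d_in projections give the same count for any head split.
paramsExpected = 4 * d_in^2;    % ~39.3 M, identical in both cases
% Hypothetical alternative: NumKeyChannels is the TOTAL projected
% dimension, so each projection is d_in-by-NumKeyChannels.
paramsCase1 = 4 * d_in * 784;   % 9,834,496  (~9.8 M)
paramsCase2 = 4 * d_in * 392;   % 4,917,248  (~4.9 M)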
  11 comments
Hana Ahmed about 18 hours ago
Thank you very much for your reply. A final question, please: if we have 8 parallel heads and each head has three projection matrices, should we expect to see 24 projection matrices in the workspace, or only the three matrices of a single head?
Umar about 8 hours ago

Hi @Hana Ahmed,

Even though each of the 8 heads conceptually has its own Q/K/V matrices, MATLAB stores them as three concatenated matrices, one each for Q, K, and V. Each matrix is sliced internally to compute the per-head projections, which is why you see only 3 matrices in the workspace instead of 24.

Script

close all; clear all; clc

numHeads = 8; d_k = 64; inputDim = 512; batchSize = 10;
X = randn(batchSize, inputDim);

% One concatenated projection matrix covering all heads
% (shown for Q; K and V work the same way)
W_Q = randn(inputDim, numHeads*d_k);
Q_full = X * W_Q;                  % batchSize x (numHeads*d_k) = 10 x 512

% Slice the concatenated result into per-head blocks
Q_heads = zeros(batchSize, numHeads, d_k);
for i = 1:numHeads
  idx = (i-1)*d_k + 1 : i*d_k;     % columns belonging to head i
  Q_heads(:, i, :) = Q_full(:, idx);
end

disp(size(Q_full))
disp(size(Q_heads))

Results:

    10   512
    10     8    64

Explanation:

  • `Q_full` shows all 8 heads concatenated.
  • `Q_heads` shows per-head slices (64 channels each).
  • This is mathematically equivalent to having separate matrices per head and is memory-efficient.
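
That equivalence is easy to verify by continuing the script above: project with each head's own sliced matrix and compare against the corresponding block of Q_full. The names W_Q_head and Q_head_i are illustrative, not part of MATLAB's implementation.

% Verify: projecting with head i's own matrix W_Q(:, idx) gives exactly
% the same result as slicing the concatenated product Q_full
maxDiff = 0;
for i = 1:numHeads
  idx = (i-1)*d_k + 1 : i*d_k;
  W_Q_head = W_Q(:, idx);          % head i's own projection matrix
  Q_head_i = X * W_Q_head;         % 10 x 64
  maxDiff = max(maxDiff, max(abs(Q_head_i - squeeze(Q_heads(:, i, :))), [], 'all'));
end
disp(maxDiff)                      % 0: the results are identical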


Answers (0)
