How is the number of parameters calculated when a multi-head self-attention layer is used in a CNN model?

I have run the example in the following link in two cases:
Case 1: NumHeads = 4, NumKeyChannels = 784
Case 2: NumHeads = 8, NumKeyChannels = 392
Note that 4 x 784 = 8 x 392 = 3136 (the size of the input feature vector to the attention layer). I calculated the number of model parameters in the two cases and got the following: 9.8 M for the first case, and 4.9 M for the second case.
I expected the number of learnable parameters to be the same. However, MATLAB reports different parameter counts.
My understanding from research papers is that the total number of parameters should not depend on how the input is split across heads: as long as the input feature vector is the same and the number of heads times the per-head size (number of channels) equals the input size, the parameter count should be identical.
Why does MATLAB’s selfAttentionLayer produce different parameter counts for these two configurations? Am I misinterpreting how the layer is implemented in this toolbox?
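For illustration, this is the back-of-the-envelope count I had in mind (a rough sketch of the standard per-head formulation, ignoring biases; the helper name countFor is just for this example):

inputDim = 3136;
% W_Q, W_K, W_V each map inputDim -> numHeads*headDim,
% and the output projection maps numHeads*headDim -> inputDim
countFor = @(numHeads, headDim) 3*inputDim*(numHeads*headDim) ...
                              + (numHeads*headDim)*inputDim;
countFor(4, 784)   % 39,337,984
countFor(8, 392)   % 39,337,984 -- identical, since 4*784 = 8*392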
3 comments
Hana Ahmed on 29 Aug 2025 at 3:44
I would be very grateful if you could suggest or provide a correct MATLAB implementation — ideally as a custom layer — that follows the standard multi-head attention equations.
Umar on 30 Aug 2025 at 9:40

Hi @Hana Ahmed,

Thanks for your follow-up! I think writing the multi-head attention mechanism from scratch would be a great way to get the transparency and control you're looking for. It will also help you understand the underlying principles better.

Here’s a quick skeleton to guide your implementation:

function Y = multiHeadAttention(X, numHeads, keyChannels)
  % X:           input sequence [seqLen, inputDim] (one row per token)
  % numHeads:    number of attention heads
  % keyChannels: key/query/value dimensionality per head
    [seqLen, inputDim] = size(X);
    d_k = keyChannels;                       % dimension per head
    % Projection weights for Q, K, V and the output projection.
    % In a custom layer these would be learnable parameters; random
    % values are used here only so the sketch runs end to end.
    W_Q = randn(inputDim, numHeads * d_k);
    W_K = randn(inputDim, numHeads * d_k);
    W_V = randn(inputDim, numHeads * d_k);
    W_O = randn(numHeads * d_k, inputDim);
    % Project the input to queries, keys, and values.
    Q = X * W_Q;                             % [seqLen, numHeads*d_k]
    K = X * W_K;
    V = X * W_V;
    % Split the channel dimension into heads: [seqLen, d_k, numHeads].
    Q = reshape(Q, seqLen, d_k, numHeads);
    K = reshape(K, seqLen, d_k, numHeads);
    V = reshape(V, seqLen, d_k, numHeads);
    % Scaled dot-product attention, computed head by head.
    attentionOutput = zeros(seqLen, d_k, numHeads);
    for h = 1:numHeads
        Qh = Q(:, :, h);                     % [seqLen, d_k]
        Kh = K(:, :, h);
        Vh = V(:, :, h);
        scores = (Qh * Kh') / sqrt(d_k);     % [seqLen, seqLen]
        % Row-wise softmax, stabilised by subtracting the row max.
        scores = scores - max(scores, [], 2);
        weights = exp(scores) ./ sum(exp(scores), 2);
        attentionOutput(:, :, h) = weights * Vh;
    end
    % Concatenate the heads and apply the output projection.
    attentionOutput = reshape(attentionOutput, seqLen, numHeads * d_k);
    Y = attentionOutput * W_O;               % [seqLen, inputDim]
end
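
Once it runs, you could sanity-check it with the sizes from your question (the token count of 10 below is just an arbitrary example):

X = randn(10, 3136);                 % 10 tokens, 3136 channels each
Y1 = multiHeadAttention(X, 4, 784);  % 4 heads x 784 channels per head
Y2 = multiHeadAttention(X, 8, 392);  % 8 heads x 392 channels per head
size(Y1), size(Y2)                   % both are [10 3136]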

I suggest you try implementing this yourself in MATLAB, following the structure above. This will give you a hands-on understanding of how the attention mechanism works.

If you run into any issues or get stuck, feel free to reach out, and I’d be happy to help debug.

Good luck with the implementation!


Answers (0)
