First, you shouldn't rotate the values, only the keys and queries.
This is wrong: v_out = (torch.bmm(v.transpose(0,1), self.R[:m, ...])).transpose(0,1)
Second, you shouldn't apply multihead attention, which has additional inner weights that will mess with the rotations you have just done.
This is wrong: activations, attn_weights = self.multihead(q_out, k_out, v_out)
Instead, you should call scaled_dot_product_attention(q_out, k_out, v_out) directly, with v_out left unrotated.
Third, every attention head should be treated the same way, and every attention head should use the same rotation frequencies.
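
The three points above could look roughly like this. It's only a sketch, assuming PyTorch 2.x and tensors shaped (batch, heads, seq_len, head_dim). The helper names rope_angles, apply_rope and rope_attention are made up for illustration, and the interleaved channel pairing and base of 10000 are assumptions, not taken from the code being discussed.

```python
import torch
import torch.nn.functional as F

def rope_angles(seq_len, head_dim, base=10000.0):
    # One frequency per pair of channels; the same table is used by every head.
    inv_freq = base ** (-torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim)
    pos = torch.arange(seq_len, dtype=torch.float32)
    return torch.outer(pos, inv_freq)  # (seq_len, head_dim // 2)

def apply_rope(x, angles):
    # x: (batch, heads, seq_len, head_dim). Rotate each (even, odd) channel pair
    # by the angle for its position; angles broadcasts over batch and heads.
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

def rope_attention(q, k, v):
    angles = rope_angles(q.shape[-2], q.shape[-1])
    q_out = apply_rope(q, angles)  # queries are rotated
    k_out = apply_rope(k, angles)  # keys are rotated
    # Values are passed through unrotated, and no extra projection weights
    # sit between the rotation and the dot product.
    return F.scaled_dot_product_attention(q_out, k_out, v)
```

Any learned q/k/v projections would be applied before the rotation, and the output projection after the attention call.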
> Second you shouldn't apply multihead attention which has additional inner weights that will mess with the rotations you have just done
Wait, does that mean that rotary embeddings don't work with multi-head attention? This is the first I have heard of this. Wouldn't this be an issue with position embeddings as well (for example, sinusoidal position embeddings are a special case of rotary embeddings)?
AFAIU, the whole idea behind rotary embeddings is a kind of hack to switch the similarity metric (the one that compares queries to keys) inside scaled_dot_product_attention without having to rewrite the optimized code of scaled_dot_product_attention.
This custom similarity metric has some properties engineered into it, mainly invariance with respect to relative positioning, and a learnable decay with increasing distance (key-query similarity decreases with increasing distance in position space, and the network can learn how important position distance is compared to feature-space distance). It's a strong prior that works well when relative positioning is important.
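
To make the relative-positioning invariance concrete, here is a toy check on a single 2-D channel pair, written as a complex number so that a rotation is just a multiplication. The names rotate and score and the frequency theta=0.1 are arbitrary choices for illustration; a real rotary embedding does this for many pairs with different frequencies.

```python
import cmath

def rotate(z, pos, theta=0.1):
    # Rotate a 2-D channel pair (stored as a complex number) by pos * theta.
    return z * cmath.exp(1j * pos * theta)

def score(q, k, m, n, theta=0.1):
    # Re(a * conj(b)) is the ordinary 2-D dot product of a and b, so this is
    # the dot product between the query rotated to position m and the key
    # rotated to position n.
    return (rotate(q, m, theta) * rotate(k, n, theta).conjugate()).real

q, k = complex(1.0, 2.0), complex(0.5, -1.5)

# Same offset m - n = 3 at three different absolute positions.
same_offset = [score(q, k, m, m - 3) for m in (3, 10, 50)]
```

The three values in same_offset agree up to floating-point error: the rotations combine into a single rotation by (m - n) * theta, so only the offset between query and key positions survives in the score.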
It's a refinement of traditional attention. It's a different and more ambitious aim than what sinusoidal position embeddings are trying to do, which is just to provide some position information to the neural network so that it can distinguish keys, and let it learn what it sees fit.
Sinusoidal position embeddings can learn some relative positioning quite easily because of trigonometry, but they have to learn it. Rotary embeddings have relative positioning baked in: everything is relative to the query position (a point of view quite similar to a convolutional network's), and the only thing they learn is how much small position distances should matter compared to large ones.