Commit 9625dc2

Shorten comment lines for apply_keep_indices
1 parent 6063f82

1 file changed (+5, -4)


timm/layers/pos_embed_sincos.py

Lines changed: 5 additions & 4 deletions
```diff
@@ -241,10 +241,11 @@ def apply_keep_indices_nlc(
         pos_embed_has_batch: bool = False,
 ) -> torch.Tensor:
     """ Apply keep indices to different ROPE shapes
-    Expected shapes:
-    * pos_embed shape [seq_len, pos_embed_dim] → output [batch_size, seq_len, pos_embed_dim]
-    * pos_embed shape [num_heads, seq_len, pos_embed_dim] → output [batch_size, num_heads, seq_len, pos_embed_dim]
-    * pos_embed shape [depth, num_heads, seq_len, pos_embed_dim] → output [batch_size, depth, num_heads, seq_len, pos_embed_dim]
+
+    Expected pos_embed shapes:
+    * [seq_len, pos_embed_dim] --> output [batch_size, seq_len, pos_embed_dim]
+    * [num_heads, seq_len, pos_embed_dim] --> output [batch_size, num_heads, seq_len, pos_embed_dim]
+    * [depth, num_heads, seq_len, pos_embed_dim] --> output [batch_size, depth, num_heads, seq_len, pos_embed_dim]
 
     And all of the above with leading batch dimension already present if `pos_embed_has_batch == True`
     """
```
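To make the documented shape contract concrete, here is a minimal sketch of the gather behavior the docstring describes, for the `pos_embed_has_batch == False` case. `gather_keep_positions` is a hypothetical standalone helper, not the timm implementation; the broadcasting steps are one assumed way to realize the documented input/output shapes.

```python
import torch

def gather_keep_positions(pos_embed: torch.Tensor, keep_indices: torch.Tensor) -> torch.Tensor:
    """Illustrative sketch (not the timm source) of the docstring's shape contract.

    pos_embed:    [..., seq_len, pos_embed_dim], no leading batch dim
    keep_indices: [batch_size, num_keep] positions to retain along seq_len
    returns:      [batch_size, ..., num_keep, pos_embed_dim]
    """
    batch_size = keep_indices.shape[0]
    # Broadcast pos_embed across the batch dimension.
    pos_embed = pos_embed.unsqueeze(0).expand(batch_size, *pos_embed.shape)
    # Insert singleton dims (e.g. num_heads, depth) until indices align with seq_len.
    indices = keep_indices
    while indices.ndim < pos_embed.ndim - 1:
        indices = indices.unsqueeze(1)
    indices = indices.expand(*pos_embed.shape[:-2], -1)
    # Expand over pos_embed_dim, then gather along the seq_len dimension.
    indices = indices.unsqueeze(-1)
    indices = indices.expand(*indices.shape[:-1], pos_embed.shape[-1])
    return pos_embed.gather(-2, indices)

# Example covering the [num_heads, seq_len, pos_embed_dim] case:
pos_embed = torch.randn(8, 196, 64)                    # [num_heads, seq_len, pos_embed_dim]
keep = torch.topk(torch.rand(2, 196), k=49).indices    # [batch_size, num_keep]
print(gather_keep_positions(pos_embed, keep).shape)    # torch.Size([2, 8, 49, 64])
```

Per the docstring's last line, when `pos_embed_has_batch == True` the leading batch dimension is already present, so the initial unsqueeze/expand step would be skipped.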
