
Commit c603c31

Fix attn pool specific num_heads / mlp_ratio being passed for PE models in NaFlexVit
1 parent 6fb3536 commit c603c31

File tree

1 file changed (+7 lines, -3 lines)

timm/models/naflexvit.py

Lines changed: 7 additions & 3 deletions
@@ -106,12 +106,16 @@ class NaFlexVitCfg:
     # Image processing
     dynamic_img_pad: bool = False  # Whether to enable dynamic padding for variable resolution
 
-    # Architecture choices
+    # Other architecture choices
     pre_norm: bool = False  # Whether to apply normalization before attention/MLP layers (start of blocks)
     final_norm: bool = True  # Whether to apply final normalization before pooling and classifier (end of blocks)
     fc_norm: Optional[bool] = None  # Whether to normalize features before final classifier (after pooling)
+
+    # Global pooling setup
     global_pool: str = 'map'  # Type of global pooling for final sequence
     pool_include_prefix: bool = False  # Whether to include class/register prefix tokens in global pooling
+    attn_pool_num_heads: Optional[int] = None  # Override num_heads for attention pool
+    attn_pool_mlp_ratio: Optional[float] = None  # Override mlp_ratio for attention pool
 
     # Weight initialization
     weight_init: str = ''  # Weight initialization scheme

@@ -1212,8 +1216,8 @@ def __init__(
         if cfg.global_pool == 'map':
             self.attn_pool = AttentionPoolLatent(
                 self.embed_dim,
-                num_heads=cfg.num_heads,
-                mlp_ratio=cfg.mlp_ratio,
+                num_heads=cfg.attn_pool_num_heads or cfg.num_heads,
+                mlp_ratio=cfg.attn_pool_mlp_ratio or cfg.mlp_ratio,
                 norm_layer=norm_layer,
                 act_layer=act_layer,
             )
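For context, a minimal sketch (not part of the commit) of the fallback behavior the `or` override introduces: when the new attn_pool_num_heads / attn_pool_mlp_ratio fields are left at None, the attention pool falls back to the backbone-wide num_heads / mlp_ratio; when they are set (as for PE models), the pool-specific values win. The PoolCfg dataclass below is a simplified, hypothetical stand-in for NaFlexVitCfg, and the numeric values are illustrative only.

    from dataclasses import dataclass
    from typing import Optional

    # Hypothetical, trimmed-down stand-in for NaFlexVitCfg with only the
    # fields relevant to this change.
    @dataclass
    class PoolCfg:
        num_heads: int = 12                            # backbone attention heads
        mlp_ratio: float = 4.0                         # backbone MLP ratio
        attn_pool_num_heads: Optional[int] = None      # Override num_heads for attention pool
        attn_pool_mlp_ratio: Optional[float] = None    # Override mlp_ratio for attention pool

    # Default case: overrides are None, so the attention pool inherits the
    # backbone's num_heads / mlp_ratio.
    cfg = PoolCfg()
    assert (cfg.attn_pool_num_heads or cfg.num_heads) == 12
    assert (cfg.attn_pool_mlp_ratio or cfg.mlp_ratio) == 4.0

    # PE-style case: pool-specific values are set and take precedence.
    cfg = PoolCfg(attn_pool_num_heads=8, attn_pool_mlp_ratio=4.3)
    assert (cfg.attn_pool_num_heads or cfg.num_heads) == 8
    assert (cfg.attn_pool_mlp_ratio or cfg.mlp_ratio) == 4.3

One caveat of the `or` pattern: an explicit 0 or 0.0 override would also fall through to the backbone value, which is harmless here since neither is a meaningful head count or MLP ratio for the pool.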
