File tree Expand file tree Collapse file tree 1 file changed +7
-3
lines changed Expand file tree Collapse file tree 1 file changed +7
-3
lines changed Original file line number Diff line number Diff line change @@ -106,12 +106,16 @@ class NaFlexVitCfg:
106
106
# Image processing
107
107
dynamic_img_pad : bool = False # Whether to enable dynamic padding for variable resolution
108
108
109
- # Architecture choices
109
+ # Other architecture choices
110
110
pre_norm : bool = False # Whether to apply normalization before attention/MLP layers (start of blocks)
111
111
final_norm : bool = True # Whether to apply final normalization before pooling and classifier (end of blocks)
112
112
fc_norm : Optional [bool ] = None # Whether to normalize features before final classifier (after pooling)
113
+
114
+ # Global pooling setup
113
115
global_pool : str = 'map' # Type of global pooling for final sequence
114
116
pool_include_prefix : bool = False # Whether to include class/register prefix tokens in global pooling
117
+ attn_pool_num_heads : Optional [int ] = None # Override num_heads for attention pool
118
+ attn_pool_mlp_ratio : Optional [float ] = None # Override mlp_ratio for attention pool
115
119
116
120
# Weight initialization
117
121
weight_init : str = '' # Weight initialization scheme
@@ -1212,8 +1216,8 @@ def __init__(
1212
1216
if cfg .global_pool == 'map' :
1213
1217
self .attn_pool = AttentionPoolLatent (
1214
1218
self .embed_dim ,
1215
- num_heads = cfg .num_heads ,
1216
- mlp_ratio = cfg .mlp_ratio ,
1219
+ num_heads = cfg .attn_pool_num_heads or cfg . num_heads ,
1220
+ mlp_ratio = cfg .attn_pool_mlp_ratio or cfg . mlp_ratio ,
1217
1221
norm_layer = norm_layer ,
1218
1222
act_layer = act_layer ,
1219
1223
)
You can’t perform that action at this time.
0 commit comments