@@ -103,8 +103,12 @@ __global__ void train(Memory<Vector, Index> head_embeddings, Memory<Vector, Inde
                 sample_loss += weight * -log(prob + kEpsilon);
             } else {
                 gradient = prob;
-                if (adversarial_temperature > kEpsilon)
+                if (adversarial_temperature > kEpsilon) {
                     weight = safe_exp((logit - bias) / adversarial_temperature) / normalizer;
+                    // the normalizer may be out of date in ASGD
+                    // so we need to clip the weight
+                    weight = min(weight, Float(1));
+                }
                 else
                     weight = 1.0 / num_negative;
                 sample_loss += weight * -log(1 - prob + kEpsilon);
@@ -198,8 +202,12 @@ __global__ void train_1_moment(Memory<Vector, Index> head_embeddings, Memory<Vec
                 sample_loss += weight * -log(prob + kEpsilon);
             } else {
                 gradient = prob;
-                if (adversarial_temperature > kEpsilon)
+                if (adversarial_temperature > kEpsilon) {
                     weight = safe_exp((logit - bias) / adversarial_temperature) / normalizer;
+                    // the normalizer may be out of date in ASGD
+                    // so we need to clip the weight
+                    weight = min(weight, Float(1));
+                }
                 else
                     weight = 1.0 / num_negative;
                 sample_loss += weight * -log(1 - prob + kEpsilon);
@@ -298,8 +306,12 @@ __global__ void train_2_moment(Memory<Vector, Index> head_embeddings, Memory<Vec
                 sample_loss += weight * -log(prob + kEpsilon);
             } else {
                 gradient = prob;
-                if (adversarial_temperature > kEpsilon)
+                if (adversarial_temperature > kEpsilon) {
                     weight = safe_exp((logit - bias) / adversarial_temperature) / normalizer;
+                    // the normalizer may be out of date in ASGD
+                    // so we need to clip the weight
+                    weight = min(weight, Float(1));
+                }
                 else
                     weight = 1.0 / num_negative;
                 sample_loss += weight * -log(1 - prob + kEpsilon);
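The clipped quantity is the self-adversarial negative-sampling weight, `exp((logit - bias) / T) / normalizer`. Under asynchronous SGD the softmax normalizer can be computed from logits that have since moved, so a single weight may exceed 1; clipping bounds each negative sample's loss contribution. Below is a minimal host-side C++ sketch of that failure mode, not repository code: the values and the name `stale_normalizer` are illustrative assumptions.

```cpp
// Minimal illustration (not repository code): under asynchronous SGD the
// softmax normalizer may be computed from stale logits, so a single
// self-adversarial weight exp(logit / T) / normalizer can exceed 1.
#include <algorithm>
#include <cmath>
#include <cstdio>

typedef float Float;

int main() {
    const Float logits[] = {2.0f, 0.5f, -1.0f}; // hypothetical negative-sample logits
    const Float temperature = 0.5f;             // plays the role of adversarial_temperature
    const Float stale_normalizer = 10.0f;       // out-of-date sum of exp(logit / T)
    for (Float logit : logits) {
        Float weight = std::exp(logit / temperature) / stale_normalizer;
        // clip as the patch does, so a stale normalizer cannot
        // inflate one sample's contribution past 1
        weight = std::min(weight, Float(1));
        std::printf("weight = %f\n", weight);
    }
    return 0;
}
```

With these numbers the first weight would be exp(4) / 10 ≈ 5.46 before clipping, which is why the patch caps it at 1 in all three kernels.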