
Commit c67605c

add weight clip to self-adversarial negative sampling
1 parent ec55cce


2 files changed: +16 -4 lines changed


include/instance/gpu/knowledge_graph.cuh

Lines changed: 15 additions & 3 deletions
@@ -103,8 +103,12 @@ __global__ void train(Memory<Vector, Index> head_embeddings, Memory<Vector, Inde
             sample_loss += weight * -log(prob + kEpsilon);
         } else {
             gradient = prob;
-            if (adversarial_temperature > kEpsilon)
+            if (adversarial_temperature > kEpsilon) {
                 weight = safe_exp((logit - bias) / adversarial_temperature) / normalizer;
+                // the normalizer may be out of date in ASGD
+                // so we need to clip the weight
+                weight = min(weight, Float(1));
+            }
             else
                 weight = 1.0 / num_negative;
             sample_loss += weight * -log(1 - prob + kEpsilon);
@@ -198,8 +202,12 @@ __global__ void train_1_moment(Memory<Vector, Index> head_embeddings, Memory<Vec
             sample_loss += weight * -log(prob + kEpsilon);
         } else {
             gradient = prob;
-            if (adversarial_temperature > kEpsilon)
+            if (adversarial_temperature > kEpsilon) {
                 weight = safe_exp((logit - bias) / adversarial_temperature) / normalizer;
+                // the normalizer may be out of date in ASGD
+                // so we need to clip the weight
+                weight = min(weight, Float(1));
+            }
             else
                 weight = 1.0 / num_negative;
             sample_loss += weight * -log(1 - prob + kEpsilon);
@@ -298,8 +306,12 @@ __global__ void train_2_moment(Memory<Vector, Index> head_embeddings, Memory<Vec
             sample_loss += weight * -log(prob + kEpsilon);
         } else {
             gradient = prob;
-            if (adversarial_temperature > kEpsilon)
+            if (adversarial_temperature > kEpsilon) {
                 weight = safe_exp((logit - bias) / adversarial_temperature) / normalizer;
+                // the normalizer may be out of date in ASGD
+                // so we need to clip the weight
+                weight = min(weight, Float(1));
+            }
             else
                 weight = 1.0 / num_negative;
             sample_loss += weight * -log(1 - prob + kEpsilon);
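
In self-adversarial negative sampling, each negative sample is weighted by a softmax over the negatives' logits, so no weight should exceed 1. Under asynchronous SGD (ASGD), however, the normalizer can be computed from logits that other workers have since updated; a single exp term may then exceed the stale sum, producing a weight above 1 that amplifies that sample's gradient. Clipping the weight at 1 bounds each negative's contribution to the loss. Below is a minimal NumPy sketch of the clipped weighting, not GraphVite's API; logits, bias, temperature, and normalizer are illustrative names mirroring the kernel variables above, and treating bias as the max logit is only one plausible reading of its role.

import numpy as np

def adversarial_weights(logits, bias, temperature, normalizer):
    # softmax-style weight per negative sample, as in the kernels above
    unnormalized = np.exp((logits - bias) / temperature)
    weights = unnormalized / normalizer
    # the normalizer may be out of date in ASGD, so clip the weight
    return np.minimum(weights, 1.0)

logits = np.array([2.0, -1.0, 0.5])
bias = logits.max()  # assumed: bias shifts logits for numerical stability
temperature = 0.5
fresh = np.exp((logits - bias) / temperature).sum()

print(adversarial_weights(logits, bias, temperature, fresh))        # all weights <= 1
print(adversarial_weights(logits, bias, temperature, 0.5 * fresh))  # stale sum: clip engages

With a fresh normalizer the clip is a no-op; it only takes effect when the normalizer lags behind the logits.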

python/graphvite/application/application.py

Lines changed: 1 addition & 1 deletion
@@ -883,7 +883,7 @@ def get_batch_size(self, sample_size):
         mem_per_sample = sample_size * (2 * 3 * np.uint32().itemsize + 1 * np.uint64().itemsize)
         max_batch_size = int(memory.available / mem_per_sample / self.MEMORY_SCALE_FACTOR)
         if max_batch_size < batch_size:
-            logger.info("Memory is not enough for optimal prediction batch size."
+            logger.info("Memory is not enough for optimal prediction batch size. "
                         "Use the maximal possible size instead.")
             batch_size = max_batch_size
         return batch_size
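
The one-line Python fix repairs a log message split across two adjacent string literals: Python concatenates such literals with no separator, so the message printed as "...batch size.Use the maximal...". Moving a trailing space into the first literal restores the gap. A quick illustration:

broken = ("Memory is not enough for optimal prediction batch size."
          "Use the maximal possible size instead.")
# -> "...batch size.Use the maximal possible size instead."

fixed = ("Memory is not enough for optimal prediction batch size. "
         "Use the maximal possible size instead.")
# -> "...batch size. Use the maximal possible size instead."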
