Hi @sgrvinod,
In the xe train function:
predicted_sequences = model(source_sequences, target_sequences, source_sequence_lengths, target_sequence_lengths) # (N, max_target_sequence_pad_length_this_batch, vocab_size)
The target_sequence_lengths passed here still include the <end> token, so the multi-head attention will also attend over the <end> token.
I think it should be target_sequence_lengths - 1 instead:
predicted_sequences = model(source_sequences, target_sequences, source_sequence_lengths, target_sequence_lengths - 1) # (N, max_target_sequence_pad_length_this_batch, vocab_size)
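For context, here is a minimal sketch (with assumed names, not the repository's actual masking code) of how sequence lengths are typically turned into a key-padding mask, showing how subtracting 1 changes whether the <end> position is attended:

```python
import torch

def length_mask(lengths, max_len):
    # lengths: (N,) number of valid (non-pad) tokens per sequence
    # returns: (N, max_len) boolean mask, True where attention is allowed
    positions = torch.arange(max_len).unsqueeze(0)  # (1, max_len)
    return positions < lengths.unsqueeze(1)         # broadcast to (N, max_len)

lengths = torch.tensor([5])                  # e.g. 4 tokens followed by <end>
print(length_mask(lengths, 7))               # positions 0..4 attended, <end> included
print(length_mask(lengths - 1, 7))           # positions 0..3 attended, <end> excluded
```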
Could you please clarify?