
why attend over the <end> token? #3

@homelifes


Hi @sgrvinod,
In the XE training function:

predicted_sequences = model(source_sequences, target_sequences, source_sequence_lengths, target_sequence_lengths) # (N, max_target_sequence_pad_length_this_batch, vocab_size)

target_sequence_lengths still counts the <end> token in each length, so in MultiHeadAttention the model will attend over the <end> token positions.

I think it should be target_sequence_lengths - 1:
predicted_sequences = model(source_sequences, target_sequences, source_sequence_lengths, target_sequence_lengths - 1) # (N, max_target_sequence_pad_length_this_batch, vocab_size)
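For reference, here is a minimal sketch (not the repo's actual code; length_mask is a hypothetical helper) of how sequence lengths are typically turned into a key padding mask in attention, showing that passing target_sequence_lengths - 1 would additionally mask out each sequence's final <end> position:

import torch

def length_mask(sequence_lengths, pad_length):
    # (N, pad_length) boolean mask: True where the key position is valid,
    # i.e. position index < sequence length
    positions = torch.arange(pad_length).unsqueeze(0)   # (1, pad_length)
    return positions < sequence_lengths.unsqueeze(1)    # (N, pad_length)

target_sequence_lengths = torch.tensor([5, 3])  # lengths including <end>
pad_length = 6

print(length_mask(target_sequence_lengths, pad_length))
# tensor([[ True,  True,  True,  True,  True, False],
#         [ True,  True,  True, False, False, False]])

print(length_mask(target_sequence_lengths - 1, pad_length))
# tensor([[ True,  True,  True,  True, False, False],   # <end> at index 4 now masked
#         [ True,  True, False, False, False, False]])  # <end> at index 2 now masked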

Please clarify.
