The transformer starts out predicting random words and then optimizes towards predicting “a”, “.”, “\n”, and “” for every position (I changed the batch size to 1; see the output below). This makes me think something is wrong with the loss function. I tried feeding the loss function random inputs, but PyTorch tensors have a “grad_fn” field that the model fills in, which I assume is used for backpropagation, and I don’t know how to fill in that field manually.
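Here is roughly what I tried, as a minimal sketch (the vocab size, sequence length, and pad index are all made up). From what I can tell, grad_fn gets attached by autograd automatically once the input tensor has requires_grad=True, rather than being something you fill in by hand:

import torch
import torch.nn as nn

vocab_size, seq_len = 5893, 99   # made-up sizes for illustration
TRG_PAD_IDX = 1                  # assumed pad index

criterion = nn.CrossEntropyLoss(ignore_index=TRG_PAD_IDX)

# Random "logits": requires_grad=True makes autograd attach grad_fn
# to everything computed from this tensor, including the loss.
logits = torch.randn(seq_len, vocab_size, requires_grad=True)
targets = torch.randint(0, vocab_size, (seq_len,))

loss = criterion(logits, targets)
print(loss)       # tensor(..., grad_fn=<NllLossBackward0>)
loss.backward()   # works because grad_fn was attached automatically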
I’m using:
criterion = nn.CrossEntropyLoss(ignore_index=TRG_PAD_IDX).to(device)
which supposedly skips over pad indexes; this line is from the original notebook.
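For context, this is how I understand the notebook applies that criterion (a sketch reconstructed from memory, with made-up sizes): the logits are flattened to [N, C] and compared against the target shifted by one token, so each position is scored against the next token.

import torch
import torch.nn as nn

TRG_PAD_IDX = 1                                  # assumed pad index
criterion = nn.CrossEntropyLoss(ignore_index=TRG_PAD_IDX)

batch_size, trg_len, output_dim = 1, 100, 5893   # made-up sizes
output = torch.randn(batch_size, trg_len - 1, output_dim)  # stand-in for model(src, trg[:, :-1])
trg = torch.randint(0, output_dim, (batch_size, trg_len))  # stand-in target ids

# Flatten so CrossEntropyLoss sees [N, C] logits against [N] targets;
# the target drops <bos> so position t is compared with token t+1.
loss = criterion(output.view(-1, output_dim), trg[:, 1:].reshape(-1))
print(loss)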
@Alireza_Saei, do you know if there’s a way to inspect how the loss function is computed in PyTorch? The documentation has the equations, but I’m not sure which dimensions they are computed over.
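For reference, here is my current understanding of what CrossEntropyLoss computes, written out by hand (a sketch with made-up sizes): log-softmax over the class dimension, take the negative log-probability of each target token, then average over the non-pad positions. It should agree with F.cross_entropy:

import torch
import torch.nn.functional as F

def manual_ce(logits, targets, pad_idx):
    # logits: [N, C], targets: [N]; the softmax runs over the class dim (dim=1)
    log_probs = F.log_softmax(logits, dim=1)
    nll = -log_probs[torch.arange(len(targets)), targets]
    mask = targets != pad_idx        # mimic ignore_index
    return nll[mask].mean()

logits = torch.randn(7, 10)          # made-up sizes
targets = torch.randint(0, 10, (7,))
print(manual_ce(logits, targets, pad_idx=1))
print(F.cross_entropy(logits, targets, ignore_index=1))  # should match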
The input SRC/TRG pairs for training look OK:
SRC: ['<bos>', 'Ein', 'kleines', 'Mädchen', 'mit', 'einem', 'rosa', 'Würfel', 'in', 'ihren', 'braunen', 'Haaren', 'macht', 'ein', 'trauriges', 'Gesicht', '.', '\n', '<eos>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']
TRG: ['<bos>', 'A', 'young', 'girl', 'with', 'pink', 'dice', 'in', 'her', 'brown', 'hair', 'has', 'a', 'sad', 'look', 'on', 'her', 'face', '.', '\n', '<eos>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']
This is the transformer’s output after each training pair:
EPOCH 1, start time 1732144480.1517425
['pocket', 'campground', 'relaxes', 'Times', 'Code', 'Princess', 'ballons', 'Teens', 'referencing', 'snowball', 'from', 'coastline', 'coastline', 'Times', 'Times', 'attentively', 'blacked', 'Teens', 'Times', 'Times', 'batch', 'pocket', 'coastline', 'tastes', 'Times', 'Teens', 'waterskis', 'Times', 'Times', 'Times', 'pocket', 'drawer', 'Toyota', 'pocket', 'pocket', 'Teens', 'coastline', 'snowball', 'pocket', 'pocket', 'tastes', 'knit', 'steel', 'Times', 'Code', 'snowball', 'Family', 'Toyota', 'Tartuffe', 'snowball', 'drawer', 'Times', 'Times', 'Times', 'attentively', 'ashtray', 'nose', 'Times', 'pocket', 'snowball', 'warmly', 'steel', 'ballons', 'perked', 'ha', 'Toyota', 'ballons', 'Code', 'Times', 'corks', 'Mrs.', 'Teens', 'Times', 'warmly', 'Times', 'ha', 'Costco', 'ha', 'Times', 'drawer', 'pocket', 'laughs', 'perked', 'leaking', 'Times', 'camera', 'steel', 'Code', 'Times', 'Times', 'Times', 'Machine', 'Artwork', 'Times', 'attentively', 'waterskis', 'maid', 'Times', 'Times', 'Times']
loss tensor(9.2950, device='cuda:0', grad_fn=<NllLossBackward0>)
['older', 'ethnicity', 'older', 'older', 'with', 'competitive', 'paperwork', 'At', 'shoe', 'older', 'older', 'older', 'competitive', 'darkly', 'Average', 'older', 'shoe', 'projection', 'contently', 'older', 'free', 'formally', 'skill', 'older', 'lesbians', 'older', 'Vegas', 'older', 'older', 'ethnicity', 'older', 'formally', 'older', 'competitive', 'Vegas', 'older', 'ethnicity', 'older', 'older', 'arrival', 'competitive', 'older', 'older', 'lesbians', 'older', 'lesbians', 'knit', 'ethnicity', 'ethnicity', 'guides', 'shoe', 'formally', 'older', 'formally', 'shoe', 'Inside', 'older', 'plucking', 'Derby', 'lesbians', 'older', 'older', 'paperwork', 'older', 'older', 'older', 'lesbians', 'older', 'older', 'competitive', 'pocket', 'paddling', 'projection', 'Derby', 'older', 'jockeys', 'older', 'older', 'ethnicity', 'older', 'older', 'flower', 'dayglo', 'competitive', 'formally', 'capital', 'older', 'shoe', 'older', 'older', 'shoe', 'projection', 'older', 'competitive', 'Machine', 'paddling', 'older', 'Average', 'older', 'older']
loss tensor(9.2938, device='cuda:0', grad_fn=<NllLossBackward0>)
['a', 'a', 'older', 'older', 'projection', 'a', 'older', 'with', 'older', 'a', 'a', 'a', 'a', 'a', 'a', 'older', 'older', 'a', 'a', 'applause', 'older', 'older', 'active', 'a', 'older', 'older', 'bathroom', 'older', 'a', 'Vegas', 'a', 'Carhartt', 'a', 'bathroom', 'older', 'older', 'a', 'a', 'a', '<eos>', 'a', 'right', 'older', 'a', 'older', 'a', 'older', 'a', 'projection', 'older', 'a', 'a', 'older', 'older', 'older', 'shoe', 'older', 'a', 'a', 'a', 'active', 'with', 'a', 'blacked', 'a', 'a', 'a', 'older', 'a', 'projection', 'older', 'older', 'older', 'older', 'a', 'a', 'a', 'a', 'older', 'a', 'a', 'with', 'older', '<eos>', 'a', 'older', 'older', 'a', '<eos>', 'a', 'Derby', 'older', 'a', 'older', 'a', 'a', 'older', 'with', 'with', 'a']
loss tensor(9.1648, device='cuda:0', grad_fn=<NllLossBackward0>)
['a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', '\n', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'has', 'a', 'a', '.', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', '.', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a']
Thanks!