Hey, thanks for the great work.
I believe the masking for AKT in line 273 is:
scores.masked_fill(mask == 0, -1e23)
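However, `masked_fill` returns a new tensor and leaves `scores` unchanged, so calling it without assigning the result has no effect and future positions are never masked. To mask in place and prevent leakage of information from the future, it should be: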
scores.masked_fill_(mask == 0, -1e23)
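A minimal sketch of the difference, assuming a PyTorch implementation. The 4×4 zero `scores` and the `torch.tril` causal mask are hypothetical stand-ins, not AKT's actual tensors:

```python
import torch

seq_len = 4
# Hypothetical causal mask: 1 where position j <= i (allowed), 0 for future positions.
mask = torch.tril(torch.ones(seq_len, seq_len))

# Out-of-place call with the result discarded: `scores` stays all zeros.
scores = torch.zeros(seq_len, seq_len)
scores.masked_fill(mask == 0, -1e23)
leaky_weights = torch.softmax(scores, dim=-1)

# In-place call: future positions in `scores` are overwritten with -1e23.
scores = torch.zeros(seq_len, seq_len)
scores.masked_fill_(mask == 0, -1e23)
causal_weights = torch.softmax(scores, dim=-1)

print(leaky_weights[0])   # tensor([0.2500, 0.2500, 0.2500, 0.2500])
print(causal_weights[0])  # tensor([1., 0., 0., 0.])
```

With the discarded out-of-place call, the first position attends uniformly to all four positions, including future ones. Reassigning the result (`scores = scores.masked_fill(mask == 0, -1e23)`) would fix it equally well; the in-place version just avoids allocating a new tensor.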
I'm happy to open a PR if you can grant access.