# Borrowed from https://github.com/EvelynFan/FaceFormer/blob/main/faceformer.py
import torch
import math


# Temporal Bias: an ALiBi-style additive attention mask in which every block of
# `period` frames shares the same distance penalty, combined with a causal mask.
def init_biased_mask(n_head, max_seq_len, period):
    def get_slopes(n):
        def get_slopes_power_of_2(n):
            start = 2 ** (-2 ** -(math.log2(n) - 3))
            ratio = start
            return [start * ratio ** i for i in range(n)]

        if math.log2(n).is_integer():
            return get_slopes_power_of_2(n)
        else:
            # For head counts that are not a power of two, take the slopes of the
            # closest smaller power of two and interleave slopes from the next one.
            closest_power_of_2 = 2 ** math.floor(math.log2(n))
            return (get_slopes_power_of_2(closest_power_of_2)
                    + get_slopes(2 * closest_power_of_2)[0::2][:n - closest_power_of_2])

    # One geometric slope per attention head.
    slopes = torch.Tensor(get_slopes(n_head))
    # Quantised frame distances: positions within the same `period` share the same bias.
    bias = torch.arange(start=0, end=max_seq_len, step=period).unsqueeze(1).repeat(1, period).view(-1) // period
    bias = -torch.flip(bias, dims=[0])
    # Lower-triangular matrix of (negative) quantised distances from each query to each key.
    alibi = torch.zeros(max_seq_len, max_seq_len)
    for i in range(max_seq_len):
        alibi[i, :i + 1] = bias[-(i + 1):]
    # Scale per head: (n_head, max_seq_len, max_seq_len).
    alibi = slopes.unsqueeze(1).unsqueeze(1) * alibi.unsqueeze(0)
    # Causal mask: 0.0 on allowed (past/current) positions, -inf on future positions.
    mask = (torch.triu(torch.ones(max_seq_len, max_seq_len)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    mask = mask.unsqueeze(0) + alibi
    return mask

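# Illustrative usage (not part of the original file): the full
# (n_head, max_seq_len, max_seq_len) mask is typically built once and then
# sliced to the current number of frames before each decoder call, e.g.
#
#     biased_mask = init_biased_mask(n_head=4, max_seq_len=600, period=30)
#     tgt_mask = biased_mask[:, :frame_num, :frame_num].to(device)
#
# The hyper-parameter values, `frame_num`, and `device` above are placeholders,
# not values defined by this module.
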
# Alignment Bias: boolean memory mask in which decoder frame i may attend only to
# encoder position i; True entries are blocked by the Transformer decoder.
def enc_dec_mask(device, T, S):
    mask = torch.ones(T, S).to(device)
    for i in range(T):
        # Unmask the diagonal so frame i attends to the i-th encoder feature.
        mask[i, i] = 0
    return (mask == 1).to(device=device)
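

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original file); the head count,
    # sequence length, and period below are illustrative values only.
    device = torch.device("cpu")

    biased_mask = init_biased_mask(n_head=4, max_seq_len=600, period=30)
    print(biased_mask.shape)  # torch.Size([4, 600, 600]): one bias matrix per head

    memory_mask = enc_dec_mask(device, 120, 120)
    print(memory_mask.shape)          # torch.Size([120, 120])
    print(memory_mask[0, 0].item())   # False: the diagonal is left unmasked
    print(memory_mask[0, 1].item())   # True: off-diagonal positions are blocked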