# Borrowed from https://github.com/EvelynFan/FaceFormer/blob/main/faceformer.py
import math

import torch


def _alibi_slopes(n_head):
    """Return the list of ALiBi per-head slopes for ``n_head`` heads.

    For a power-of-two head count ``n`` the slopes are the geometric sequence
    ``start, start**2, ..., start**n`` with ``start = 2 ** (-8 / n)``.  Other
    head counts take every slope of the nearest lower power of two and fill
    the remainder with every other slope of the next power of two.
    """

    def power_of_2_slopes(n):
        # Same expression as upstream; algebraically equal to 2 ** (-8 / n).
        start = (2 ** (-2 ** -(math.log2(n) - 3)))
        ratio = start
        return [start * ratio ** i for i in range(n)]

    if math.log2(n_head).is_integer():
        return power_of_2_slopes(n_head)
    closest_power_of_2 = 2 ** math.floor(math.log2(n_head))
    return (power_of_2_slopes(closest_power_of_2)
            + _alibi_slopes(2 * closest_power_of_2)[0::2][:n_head - closest_power_of_2])


# Temporal Bias
def init_biased_mask(n_head, max_seq_len, period):
    """Build the causal, period-quantised ALiBi bias (temporal bias).

    Entry ``[h, i, j]`` equals ``-slope[h] * ((i - j) // period)`` for
    ``j <= i`` and ``-inf`` for ``j > i``.  A position therefore only sees
    itself and earlier positions, and the penalty grows by one slope step for
    every ``period`` positions of distance.  Presumably this is added to the
    attention logits as a float ``attn_mask`` (confirm against callers).

    Args:
        n_head: Number of attention heads; one slope per head.
        max_seq_len: Sequence length covered by the mask.
        period: Number of consecutive distances that share one bias value.

    Returns:
        Tensor of the default float dtype with shape
        ``(n_head, max_seq_len, max_seq_len)``.

    Raises:
        ValueError: If ``period`` is not positive.
    """
    if period <= 0:
        raise ValueError(f"period must be positive, got {period}")

    slopes = torch.tensor(_alibi_slopes(n_head))

    positions = torch.arange(max_seq_len)
    # Query-key distance i - j.  It is clamped to 0 above the diagonal; those
    # entries are overwritten by the -inf causal mask below anyway.
    distance = (positions.unsqueeze(1) - positions.unsqueeze(0)).clamp(min=0)
    # Distances are non-negative, so `//` floors and truncates identically.
    alibi = -(distance // period).to(torch.get_default_dtype())
    alibi = slopes.view(-1, 1, 1) * alibi.unsqueeze(0)

    # 0 on and below the diagonal, -inf strictly above it (future positions).
    causal = torch.full((max_seq_len, max_seq_len), float("-inf")).triu(diagonal=1)
    return causal.unsqueeze(0) + alibi


# Alignment Bias
def enc_dec_mask(device, T, S):
    """Build the boolean encoder-decoder alignment mask (alignment bias).

    The mask is ``False`` exactly on the diagonal ``[i, i]`` and ``True``
    everywhere else.  Used as a PyTorch boolean ``attn_mask``, ``True`` marks
    positions that may not be attended to, so target step ``i`` sees only
    source step ``i`` (presumed usage; confirm against callers).

    Args:
        device: Device on which the mask is created.
        T: Number of target (query) steps.
        S: Number of source (key) steps; must be at least ``T``.

    Returns:
        Bool tensor of shape ``(T, S)`` on ``device``.

    Raises:
        ValueError: If ``T > S``, i.e. some target step has no source step.
    """
    if T > S:
        raise ValueError(
            f"alignment mask needs at least as many source steps as target "
            f"steps, got T={T}, S={S}"
        )
    # One allocation on the target device instead of T element-wise writes.
    return torch.eye(T, S, device=device) == 0