first commit
This commit is contained in:
39
emotalk_own/utils.py
Normal file
39
emotalk_own/utils.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# Borrowed from https://github.com/EvelynFan/FaceFormer/blob/main/faceformer.py
|
||||
import torch
|
||||
import math
|
||||
|
||||
|
||||
# Temporal Bias
|
||||
def init_biased_mask(n_head, max_seq_len, period):
|
||||
def get_slopes(n):
|
||||
def get_slopes_power_of_2(n):
|
||||
start = (2 ** (-2 ** -(math.log2(n) - 3)))
|
||||
ratio = start
|
||||
return [start * ratio ** i for i in range(n)]
|
||||
|
||||
if math.log2(n).is_integer():
|
||||
return get_slopes_power_of_2(n)
|
||||
else:
|
||||
closest_power_of_2 = 2 ** math.floor(math.log2(n))
|
||||
return get_slopes_power_of_2(closest_power_of_2) + get_slopes(2 * closest_power_of_2)[0::2][
|
||||
:n - closest_power_of_2]
|
||||
|
||||
slopes = torch.Tensor(get_slopes(n_head))
|
||||
bias = torch.arange(start=0, end=max_seq_len, step=period).unsqueeze(1).repeat(1, period).view(-1) // (period)
|
||||
bias = - torch.flip(bias, dims=[0])
|
||||
alibi = torch.zeros(max_seq_len, max_seq_len)
|
||||
for i in range(max_seq_len):
|
||||
alibi[i, :i + 1] = bias[-(i + 1):]
|
||||
alibi = slopes.unsqueeze(1).unsqueeze(1) * alibi.unsqueeze(0)
|
||||
mask = (torch.triu(torch.ones(max_seq_len, max_seq_len)) == 1).transpose(0, 1)
|
||||
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
|
||||
mask = mask.unsqueeze(0) + alibi
|
||||
return mask
|
||||
|
||||
|
||||
# Alignment Bias
|
||||
def enc_dec_mask(device, T, S):
|
||||
mask = torch.ones(T, S).to(device)
|
||||
for i in range(T):
|
||||
mask[i, i] = 0
|
||||
return (mask == 1).to(device=device)
|
||||
Reference in New Issue
Block a user