import math

import numpy as np
import torch
import torch.nn as nn
from transformers import Wav2Vec2Processor, Wav2Vec2FeatureExtractor

from wav2vec import Wav2Vec2Model, Wav2Vec2ForSpeechClassification
from utils import init_biased_mask, enc_dec_mask


class EmoTalk(nn.Module):
    def __init__(self, args):
        super(EmoTalk, self).__init__()
        self.feature_dim = args.feature_dim
        self.bs_dim = args.bs_dim
        self.device = args.device
        self.batch_size = args.batch_size
        # Content branch: pretrained wav2vec2 encoder and its processor; the
        # convolutional feature extractor is kept frozen.
        self.audio_encoder_cont = Wav2Vec2Model.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-english")
        self.processor = Wav2Vec2Processor.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-english")
        self.audio_encoder_cont.feature_extractor._freeze_parameters()
        # Emotion branch: pretrained speech-emotion-recognition wav2vec2 and its
        # feature extractor; its convolutional feature extractor is also frozen.
        self.audio_encoder_emo = Wav2Vec2ForSpeechClassification.from_pretrained(
            "r-f/wav2vec-english-speech-emotion-recognition")
        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
            "r-f/wav2vec-english-speech-emotion-recognition")
        self.audio_encoder_emo.wav2vec2.feature_extractor._freeze_parameters()
        self.max_seq_len = args.max_seq_len
        # Linear projections for the content and emotion feature streams.
        self.audio_feature_map_cont = nn.Linear(1024, 512)
        self.audio_feature_map_emo = nn.Linear(1024, 832)
        self.audio_feature_map_emo2 = nn.Linear(832, 256)
        self.relu = nn.ReLU()
        # Biased attention mask for the decoder (see utils.init_biased_mask).
        self.biased_mask1 = init_biased_mask(n_head=4, max_seq_len=args.max_seq_len, period=args.period)
        # One-hot emotion-level (2 classes) and person (24 subjects) codes,
        # each embedded into 32 dimensions.
        self.one_hot_level = np.eye(2)
        self.obj_vector_level = nn.Linear(2, 32)
        self.one_hot_person = np.eye(24)
        self.obj_vector_person = nn.Linear(24, 32)
        decoder_layer = nn.TransformerDecoderLayer(d_model=args.feature_dim, nhead=4, dim_feedforward=args.feature_dim,
                                                   batch_first=True)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=1)
        # Final projection from decoder features to blendshape coefficients,
        # zero-initialized.
        self.bs_map_r = nn.Linear(self.feature_dim, self.bs_dim)
        nn.init.constant_(self.bs_map_r.weight, 0)
        nn.init.constant_(self.bs_map_r.bias, 0)

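    # Note (inferred from the layer shapes above): the transformer decoder input
    # is the frame-wise concatenation of 512 content + 256 emotion + 32 level
    # + 32 person features, so args.feature_dim is expected to be 832.
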
    def forward(self, data):
        # Cross-reconstruction training pass. `data` provides two audio clips
        # ("input12", "input21"), their blendshape targets ("target11",
        # "target12"), and the emotion-level and person indices.
        frame_num11 = data["target11"].shape[1]
        frame_num12 = data["target12"].shape[1]
        # Content features from input12, aligned to both target frame counts
        # via the frame_num argument of the modified Wav2Vec2Model.
        inputs12 = self.processor(torch.squeeze(data["input12"]), sampling_rate=16000, return_tensors="pt",
                                  padding="longest").input_values.to(self.device)
        hidden_states_cont1 = self.audio_encoder_cont(inputs12, frame_num=frame_num11).last_hidden_state
        hidden_states_cont12 = self.audio_encoder_cont(inputs12, frame_num=frame_num12).last_hidden_state
        # Emotion features from input21 and input12.
        inputs21 = self.feature_extractor(torch.squeeze(data["input21"]), sampling_rate=16000, padding=True,
                                          return_tensors="pt").input_values.to(self.device)
        inputs12 = self.feature_extractor(torch.squeeze(data["input12"]), sampling_rate=16000, padding=True,
                                          return_tensors="pt").input_values.to(self.device)

        output_emo1 = self.audio_encoder_emo(inputs21, frame_num=frame_num11)
        output_emo2 = self.audio_encoder_emo(inputs12, frame_num=frame_num12)

        hidden_states_emo1 = output_emo1.hidden_states
        hidden_states_emo2 = output_emo2.hidden_states

        # Emotion logits predicted from input21.
        label1 = output_emo1.logits
        onehot_level = self.one_hot_level[data["level"]]
        onehot_level = torch.from_numpy(onehot_level).to(self.device).float()
        onehot_person = self.one_hot_person[data["person"]]
        onehot_person = torch.from_numpy(onehot_person).to(self.device).float()
        if data["target11"].shape[0] == 1:
            obj_embedding_person = self.obj_vector_person(onehot_person).unsqueeze(0)
            obj_embedding_level = self.obj_vector_level(onehot_level).unsqueeze(0)
        else:
            obj_embedding_level = self.obj_vector_level(onehot_level).unsqueeze(0).permute(1, 0, 2)
            obj_embedding_person = self.obj_vector_person(onehot_person).unsqueeze(0).permute(1, 0, 2)

        # Broadcast the level and person embeddings over the frame axis.
        obj_embedding_level11 = obj_embedding_level.repeat(1, frame_num11, 1)
        obj_embedding_level12 = obj_embedding_level.repeat(1, frame_num12, 1)
        obj_embedding_person11 = obj_embedding_person.repeat(1, frame_num11, 1)
        obj_embedding_person12 = obj_embedding_person.repeat(1, frame_num12, 1)
        hidden_states_cont1 = self.audio_feature_map_cont(hidden_states_cont1)
        hidden_states_emo11_832 = self.audio_feature_map_emo(hidden_states_emo1)
        hidden_states_emo11_256 = self.relu(self.audio_feature_map_emo2(hidden_states_emo11_832))

        hidden_states11 = torch.cat(
            [hidden_states_cont1, hidden_states_emo11_256, obj_embedding_level11, obj_embedding_person11], dim=2)
        hidden_states_cont12 = self.audio_feature_map_cont(hidden_states_cont12)
        hidden_states_emo12_832 = self.audio_feature_map_emo(hidden_states_emo2)
        hidden_states_emo12_256 = self.relu(self.audio_feature_map_emo2(hidden_states_emo12_832))

        hidden_states12 = torch.cat(
            [hidden_states_cont12, hidden_states_emo12_256, obj_embedding_level12, obj_embedding_person12], dim=2)
        # The biased target masks are only constructed for batch size 1 in this code.
        if data["target11"].shape[0] == 1:
            tgt_mask11 = self.biased_mask1[:, :hidden_states11.shape[1], :hidden_states11.shape[1]].clone().detach().to(
                device=self.device)
            tgt_mask22 = self.biased_mask1[:, :hidden_states12.shape[1], :hidden_states12.shape[1]].clone().detach().to(
                device=self.device)

        memory_mask11 = enc_dec_mask(self.device, hidden_states11.shape[1], hidden_states11.shape[1])
        memory_mask12 = enc_dec_mask(self.device, hidden_states12.shape[1], hidden_states12.shape[1])
        bs_out11 = self.transformer_decoder(hidden_states11, hidden_states_emo11_832, tgt_mask=tgt_mask11,
                                            memory_mask=memory_mask11)
        bs_out12 = self.transformer_decoder(hidden_states12, hidden_states_emo12_832, tgt_mask=tgt_mask22,
                                            memory_mask=memory_mask12)
        bs_output11 = self.bs_map_r(bs_out11)
        bs_output12 = self.bs_map_r(bs_out12)

        return bs_output11, bs_output12, label1

    def predict(self, audio, level, person):
        # Inference for a single clip: `audio` is a (1, num_samples) waveform at
        # 16 kHz; `level` and `person` index the one-hot tables. The output
        # frame count corresponds to 30 fps.
        frame_num11 = math.ceil(audio.shape[1] / 16000 * 30)
        inputs12 = self.processor(torch.squeeze(audio), sampling_rate=16000, return_tensors="pt",
                                  padding="longest").input_values.to(self.device)
        hidden_states_cont1 = self.audio_encoder_cont(inputs12, frame_num=frame_num11).last_hidden_state
        inputs12 = self.feature_extractor(torch.squeeze(audio), sampling_rate=16000, padding=True,
                                          return_tensors="pt").input_values.to(self.device)
        output_emo1 = self.audio_encoder_emo(inputs12, frame_num=frame_num11)
        hidden_states_emo1 = output_emo1.hidden_states

        onehot_level = self.one_hot_level[level]
        onehot_level = torch.from_numpy(onehot_level).to(self.device).float()
        onehot_person = self.one_hot_person[person]
        onehot_person = torch.from_numpy(onehot_person).to(self.device).float()
        if audio.shape[0] == 1:
            obj_embedding_person = self.obj_vector_person(onehot_person).unsqueeze(0)
            obj_embedding_level = self.obj_vector_level(onehot_level).unsqueeze(0)
        else:
            obj_embedding_level = self.obj_vector_level(onehot_level).unsqueeze(0).permute(1, 0, 2)
            obj_embedding_person = self.obj_vector_person(onehot_person).unsqueeze(0).permute(1, 0, 2)

        obj_embedding_level11 = obj_embedding_level.repeat(1, frame_num11, 1)
        obj_embedding_person11 = obj_embedding_person.repeat(1, frame_num11, 1)
        hidden_states_cont1 = self.audio_feature_map_cont(hidden_states_cont1)
        hidden_states_emo11_832 = self.audio_feature_map_emo(hidden_states_emo1)
        hidden_states_emo11_256 = self.relu(
            self.audio_feature_map_emo2(hidden_states_emo11_832))

        # Fuse content, emotion, level and person streams frame by frame.
        hidden_states11 = torch.cat(
            [hidden_states_cont1, hidden_states_emo11_256, obj_embedding_level11, obj_embedding_person11], dim=2)
        # As in forward(), the biased mask is only built for batch size 1.
        if audio.shape[0] == 1:
            tgt_mask11 = self.biased_mask1[:, :hidden_states11.shape[1],
                         :hidden_states11.shape[1]].clone().detach().to(device=self.device)

        memory_mask11 = enc_dec_mask(self.device, hidden_states11.shape[1], hidden_states11.shape[1])
        bs_out11 = self.transformer_decoder(hidden_states11, hidden_states_emo11_832, tgt_mask=tgt_mask11,
                                            memory_mask=memory_mask11)
        bs_output11 = self.bs_map_r(bs_out11)

        return bs_output11
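

# Minimal usage sketch: it builds an args namespace with the fields read in
# __init__ and runs predict() on random audio. The concrete values below
# (bs_dim=52 blendshapes, period=30, max_seq_len=5000) and the checkpoint path
# are assumptions for illustration; in practice trained EmoTalk weights would
# be loaded before calling predict().
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        feature_dim=832,   # 512 + 256 + 32 + 32, matching the projections above
        bs_dim=52,         # assumed number of output blendshape coefficients
        device="cuda" if torch.cuda.is_available() else "cpu",
        batch_size=1,
        max_seq_len=5000,  # assumed upper bound on decoded frames
        period=30,         # assumed period for init_biased_mask
    )
    model = EmoTalk(args).to(args.device).eval()
    # model.load_state_dict(torch.load("EmoTalk.pth", map_location=args.device))  # hypothetical checkpoint path

    audio = torch.randn(1, 16000 * 3)  # placeholder: 3 s of 16 kHz audio, kept on CPU for the processor
    level, person = torch.tensor([1]), torch.tensor([0])  # indices into the 2-level / 24-person tables
    with torch.no_grad():
        blendshapes = model.predict(audio, level, person)
    print(blendshapes.shape)  # expected: (1, 90, args.bs_dim) for 3 s at 30 fps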