import math

import numpy as np
import torch
import torch.nn as nn
from transformers import Wav2Vec2Processor, Wav2Vec2FeatureExtractor

from wav2vec import Wav2Vec2Model, Wav2Vec2ForSpeechClassification
from utils import init_biased_mask, enc_dec_mask
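# Note: "wav2vec" and "utils" are local modules shipped alongside this file; they
# provide the custom Wav2Vec2 wrappers that accept the frame_num argument used
# below, plus the init_biased_mask / enc_dec_mask attention-mask helpers.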
class EmoTalk(nn.Module):
    def __init__(self, args):
        super(EmoTalk, self).__init__()
        self.feature_dim = args.feature_dim
        self.bs_dim = args.bs_dim
        self.device = args.device
        self.batch_size = args.batch_size
        # Content branch: wav2vec 2.0 encoder for speech content; its CNN feature
        # extractor is frozen so only the transformer layers are trainable.
        self.audio_encoder_cont = Wav2Vec2Model.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-english")
        self.processor = Wav2Vec2Processor.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-english")
        self.audio_encoder_cont.feature_extractor._freeze_parameters()
        # Emotion branch: wav2vec 2.0 classifier pretrained for speech emotion recognition.
        self.audio_encoder_emo = Wav2Vec2ForSpeechClassification.from_pretrained(
            "r-f/wav2vec-english-speech-emotion-recognition")
        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
            "r-f/wav2vec-english-speech-emotion-recognition")
        self.audio_encoder_emo.wav2vec2.feature_extractor._freeze_parameters()
        self.max_seq_len = args.max_seq_len
        # Projections: 1024-d wav2vec hidden states -> 512-d content and 832-d / 256-d emotion features.
        self.audio_feature_map_cont = nn.Linear(1024, 512)
        self.audio_feature_map_emo = nn.Linear(1024, 832)
        self.audio_feature_map_emo2 = nn.Linear(832, 256)
        self.relu = nn.ReLU()
        # Periodically biased attention mask for the decoder self-attention, built once up to max_seq_len.
        self.biased_mask1 = init_biased_mask(n_head=4, max_seq_len=args.max_seq_len, period=args.period)
        # One-hot tables and embeddings for emotion level (2 classes) and speaker identity (24 identities).
        self.one_hot_level = np.eye(2)
        self.obj_vector_level = nn.Linear(2, 32)
        self.one_hot_person = np.eye(24)
        self.obj_vector_person = nn.Linear(24, 32)
        decoder_layer = nn.TransformerDecoderLayer(d_model=args.feature_dim, nhead=4, dim_feedforward=args.feature_dim,
                                                   batch_first=True)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=1)
        # Output head mapping decoder features to blendshape coefficients; zero-initialised.
        self.bs_map_r = nn.Linear(self.feature_dim, self.bs_dim)
        nn.init.constant_(self.bs_map_r.weight, 0)
        nn.init.constant_(self.bs_map_r.bias, 0)
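    # Per-frame decoder input layout (inferred from the layer sizes above):
    #   512 content + 256 emotion + 32 emotion-level + 32 person = 832 dims,
    # so args.feature_dim is expected to be 832 to match the decoder's d_model.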
    def forward(self, data):
        # data is expected to provide: "input12" / "input21" (raw 16 kHz audio),
        # "target11" / "target12" (blendshape targets, used here only for their frame counts),
        # and integer "level" / "person" labels indexing the one-hot tables above.
        frame_num11 = data["target11"].shape[1]
        frame_num12 = data["target12"].shape[1]
        # Content features, decoded to both target frame lengths.
        inputs12 = self.processor(torch.squeeze(data["input12"]), sampling_rate=16000, return_tensors="pt",
                                  padding="longest").input_values.to(self.device)
        hidden_states_cont1 = self.audio_encoder_cont(inputs12, frame_num=frame_num11).last_hidden_state
        hidden_states_cont12 = self.audio_encoder_cont(inputs12, frame_num=frame_num12).last_hidden_state
        # Emotion features and emotion logits from both audio clips.
        inputs21 = self.feature_extractor(torch.squeeze(data["input21"]), sampling_rate=16000, padding=True,
                                          return_tensors="pt").input_values.to(self.device)
        inputs12 = self.feature_extractor(torch.squeeze(data["input12"]), sampling_rate=16000, padding=True,
                                          return_tensors="pt").input_values.to(self.device)
        output_emo1 = self.audio_encoder_emo(inputs21, frame_num=frame_num11)
        output_emo2 = self.audio_encoder_emo(inputs12, frame_num=frame_num12)
        hidden_states_emo1 = output_emo1.hidden_states
        hidden_states_emo2 = output_emo2.hidden_states
        label1 = output_emo1.logits
        # Embed emotion level and person identity from their one-hot codes.
        onehot_level = self.one_hot_level[data["level"]]
        onehot_level = torch.from_numpy(onehot_level).to(self.device).float()
        onehot_person = self.one_hot_person[data["person"]]
        onehot_person = torch.from_numpy(onehot_person).to(self.device).float()
        if data["target11"].shape[0] == 1:
            obj_embedding_person = self.obj_vector_person(onehot_person).unsqueeze(0)
            obj_embedding_level = self.obj_vector_level(onehot_level).unsqueeze(0)
        else:
            obj_embedding_level = self.obj_vector_level(onehot_level).unsqueeze(0).permute(1, 0, 2)
            obj_embedding_person = self.obj_vector_person(onehot_person).unsqueeze(0).permute(1, 0, 2)
        obj_embedding_level11 = obj_embedding_level.repeat(1, frame_num11, 1)
        obj_embedding_level12 = obj_embedding_level.repeat(1, frame_num12, 1)
        obj_embedding_person11 = obj_embedding_person.repeat(1, frame_num11, 1)
        obj_embedding_person12 = obj_embedding_person.repeat(1, frame_num12, 1)
        # Concatenate content (512) + emotion (256) + level (32) + person (32) features per frame.
        hidden_states_cont1 = self.audio_feature_map_cont(hidden_states_cont1)
        hidden_states_emo11_832 = self.audio_feature_map_emo(hidden_states_emo1)
        hidden_states_emo11_256 = self.relu(self.audio_feature_map_emo2(hidden_states_emo11_832))
        hidden_states11 = torch.cat(
            [hidden_states_cont1, hidden_states_emo11_256, obj_embedding_level11, obj_embedding_person11], dim=2)
        hidden_states_cont12 = self.audio_feature_map_cont(hidden_states_cont12)
        hidden_states_emo12_832 = self.audio_feature_map_emo(hidden_states_emo2)
        hidden_states_emo12_256 = self.relu(self.audio_feature_map_emo2(hidden_states_emo12_832))
        hidden_states12 = torch.cat(
            [hidden_states_cont12, hidden_states_emo12_256, obj_embedding_level12, obj_embedding_person12], dim=2)
        # Decoder masks; note the biased target masks are only built for batch size 1 here.
        if data["target11"].shape[0] == 1:
            tgt_mask11 = self.biased_mask1[:, :hidden_states11.shape[1], :hidden_states11.shape[1]].clone().detach().to(
                device=self.device)
            tgt_mask22 = self.biased_mask1[:, :hidden_states12.shape[1], :hidden_states12.shape[1]].clone().detach().to(
                device=self.device)
        memory_mask11 = enc_dec_mask(self.device, hidden_states11.shape[1], hidden_states11.shape[1])
        memory_mask12 = enc_dec_mask(self.device, hidden_states12.shape[1], hidden_states12.shape[1])
        # Decode blendshape features, attending to the 832-d emotion states as memory.
        bs_out11 = self.transformer_decoder(hidden_states11, hidden_states_emo11_832, tgt_mask=tgt_mask11,
                                            memory_mask=memory_mask11)
        bs_out12 = self.transformer_decoder(hidden_states12, hidden_states_emo12_832, tgt_mask=tgt_mask22,
                                            memory_mask=memory_mask12)
        bs_output11 = self.bs_map_r(bs_out11)
        bs_output12 = self.bs_map_r(bs_out12)
        return bs_output11, bs_output12, label1
    def predict(self, audio, level, person):
        # Map 16 kHz samples to 30 fps blendshape frames.
        frame_num11 = math.ceil(audio.shape[1] / 16000 * 30)
        inputs12 = self.processor(torch.squeeze(audio), sampling_rate=16000, return_tensors="pt",
                                  padding="longest").input_values.to(self.device)
        hidden_states_cont1 = self.audio_encoder_cont(inputs12, frame_num=frame_num11).last_hidden_state
        inputs12 = self.feature_extractor(torch.squeeze(audio), sampling_rate=16000, padding=True,
                                          return_tensors="pt").input_values.to(self.device)
        output_emo1 = self.audio_encoder_emo(inputs12, frame_num=frame_num11)
        hidden_states_emo1 = output_emo1.hidden_states
        onehot_level = self.one_hot_level[level]
        onehot_level = torch.from_numpy(onehot_level).to(self.device).float()
        onehot_person = self.one_hot_person[person]
        onehot_person = torch.from_numpy(onehot_person).to(self.device).float()
        if audio.shape[0] == 1:
            obj_embedding_person = self.obj_vector_person(onehot_person).unsqueeze(0)
            obj_embedding_level = self.obj_vector_level(onehot_level).unsqueeze(0)
        else:
            obj_embedding_level = self.obj_vector_level(onehot_level).unsqueeze(0).permute(1, 0, 2)
            obj_embedding_person = self.obj_vector_person(onehot_person).unsqueeze(0).permute(1, 0, 2)
        obj_embedding_level11 = obj_embedding_level.repeat(1, frame_num11, 1)
        obj_embedding_person11 = obj_embedding_person.repeat(1, frame_num11, 1)
        hidden_states_cont1 = self.audio_feature_map_cont(hidden_states_cont1)
        hidden_states_emo11_832 = self.audio_feature_map_emo(hidden_states_emo1)
        hidden_states_emo11_256 = self.relu(
            self.audio_feature_map_emo2(hidden_states_emo11_832))
        hidden_states11 = torch.cat(
            [hidden_states_cont1, hidden_states_emo11_256, obj_embedding_level11, obj_embedding_person11], dim=2)
        # As in forward(), the biased target mask is only built for batch size 1.
        if audio.shape[0] == 1:
            tgt_mask11 = self.biased_mask1[:, :hidden_states11.shape[1],
                         :hidden_states11.shape[1]].clone().detach().to(device=self.device)
        memory_mask11 = enc_dec_mask(self.device, hidden_states11.shape[1], hidden_states11.shape[1])
        bs_out11 = self.transformer_decoder(hidden_states11, hidden_states_emo11_832, tgt_mask=tgt_mask11,
                                            memory_mask=memory_mask11)
        bs_output11 = self.bs_map_r(bs_out11)
        return bs_output11
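

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original file. The argument values
    # below are assumptions for illustration: feature_dim must match the 832-d
    # concatenated features, while bs_dim, max_seq_len, period, device and the
    # checkpoint path should come from the EmoTalk configuration you actually use.
    import argparse

    args = argparse.Namespace(feature_dim=832, bs_dim=52, device="cpu",
                              batch_size=1, max_seq_len=5000, period=30)
    model = EmoTalk(args).to(args.device).eval()

    # Optionally load pretrained EmoTalk weights (hypothetical path):
    # model.load_state_dict(torch.load("./pretrain_model/EmoTalk.pth",
    #                                  map_location=args.device), strict=False)

    # One second of dummy 16 kHz audio; replace with a real waveform tensor of
    # shape (1, num_samples), e.g. loaded at a 16000 Hz sampling rate.
    audio = torch.randn(1, 16000)
    with torch.no_grad():
        blendshapes = model.predict(audio, level=np.array([1]), person=np.array([0]))
    print(blendshapes.shape)  # expected: (1, num_frames, args.bs_dim)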