Files
miapia-backend/miapia_own/main.py
George Kasparyants 3633aa99e5 initial commit
2024-04-24 06:57:30 +04:00

244 lines
8.6 KiB
Python

import sys
import pandas as pd
import argparse
import base64
from flask import send_file, Response
from flask_socketio import emit
from piedemo.fields.ajax_group import AjaxChatField, AjaxGroup
from piedemo.fields.grid import VStack, HStack, SpaceField
from piedemo.fields.inputs.hidden import InputHiddenField
from piedemo.fields.outputs.colored_text import ptext, OutputColoredTextField
from piedemo.fields.outputs.json import OutputJSONField
from piedemo.fields.outputs.progress import ProgressField
from piedemo.fields.outputs.video import OutputVideoField
from piedemo.web import Web
import os
import io
from piedemo.page import Page
from piedemo.hub.svgpil import SVGImage
from piedemo.fields.outputs.table import OutputTableField
from piedemo.fields.inputs.int_list import InputIntListField
from piedemo.fields.navigation import Navigation
from piedemo.fields.inputs.chat import ChatField
import librosa
import uuid
import numpy as np
import redis
import argparse
from scipy.signal import savgol_filter
import torch
import random
import os, subprocess
import shlex
from tqdm import tqdm
from aihandler import AIHandler
from pieinfer import PieInfer, render_video, construct_video
import torch
from TTS.api import TTS
# Torch device for the TTS model: the GPU when torch can see one, else the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
def get_asset(fname):
    """Load ``assets/<fname>`` (next to this module) and return its SVG markup."""
    asset_dir = os.path.join(os.path.dirname(__file__), "assets")
    return SVGImage.open(os.path.join(asset_dir, fname)).svg_content
class MainPage(Page):
    """Chat page that replies with text, synthesized speech and a rendered video.

    Chat history lives in Redis under ``user-session:<session_id>`` as
    ``message_<i>`` fields: even ``i`` are user turns ("self"), odd ``i`` are
    assistant turns ("ChatGPT").
    """

    # Per-session artifacts (<session_id>.wav / <session_id>.npy) are written
    # here. render_video/construct_video receive only the session id, so they
    # presumably read from this same directory -- TODO confirm before changing.
    _AUDIO_DIR = "./audio"

    def __init__(self, model_name: str, *,
                 speaker_wav: str = "/home/ubuntu/repo/of_couse_here.wav",
                 redis_host: str = "localhost",
                 redis_port: int = 6379):
        """Load the TTS/inference models, connect to Redis and build the layout.

        Args:
            model_name: Model passed to ``render_video``. ``"render"`` also marks
                the SIMPLE navigation link as active; any other value marks MIA PIA.
            speaker_wav: Reference voice clip passed to XTTS as ``speaker_wav``.
            redis_host: Host of the Redis server that stores chat history.
            redis_port: Port of the Redis server that stores chat history.
        """
        super().__init__()
        self.infer = PieInfer()
        self.tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(device)
        self.r = redis.Redis(host=redis_host, port=redis_port, decode_responses=True)
        self.aihandler = AIHandler()
        self.speaker_wav = speaker_wav
        self.fields = Navigation(
            AjaxGroup("ChatGroup", VStack([
                HStack([
                    AjaxChatField("Chat",
                                  self.register_ajax(f"/refresh_{model_name}",
                                                     self.message_sent),
                                  deps_names=["sid",
                                              "session_id",
                                              "Chat",
                                              "Chat__piedemo__file"],
                                  use_socketio_support=True,
                                  nopie=True,
                                  style={"height": "100%"}),
                    OutputColoredTextField("video",
                                           nopie=True,
                                           use_socketio_support=True),
                ], xs=[8, 4]),
                ProgressField("progress",
                              nopie=True,
                              use_socketio_support=True),
                InputHiddenField("session_id", None),
            ]), no_return=True),
            no_submit=True,
            page_title="MIA PIA",
            page_style={},
        )
        self.fields.add_link("SIMPLE", "/simple", active=model_name == "render")
        self.fields.add_link("MIA PIA", "/nice", active=model_name != "render")
        self.model_name = model_name

    def get_content(self, **kwargs):
        """Return the page for a new visitor, seeded with a fresh random session id."""
        fields = self.fields.copy()
        fields.child_loc["Chat"].set_default_options(["Hello! What is your name?", "Say one word and stop."])
        # Custom chat avatars (assets/avatar.svg, assets/dog.svg via get_asset)
        # are currently disabled.
        session_id = str(uuid.uuid4())
        return self.fill(fields, {
            "video": "\n",
            "session_id": session_id,
        })

    def message_sent(self, **data):
        """Handle one chat message: reply with text, then speech, then a rendered video.

        Status is pushed only to the caller's socket (``sid``) in stages:
        "Generating text..." -> "Generating audio..." -> "Rendering..." (with a
        progress bar) -> the final video and reply text.

        Args:
            **data: Raw AJAX payload with ``sid``, ``session_id`` and ``Chat``
                (see ``deps_names`` in ``__init__``).

        Raises:
            ValueError: if ``session_id`` is not a UUID. The id is client-supplied
                and is used in file paths, a Redis key and emitted HTML.
        """
        sid = data['sid']
        self.emit(self.fields.child_loc["Chat"].clear_input(), to=sid)
        # Reset the video pane for this client only (it used to be broadcast).
        self.emit(self.fields.child_loc["video"].update("\n"), to=sid)

        data = self.parse(self.fields, data)
        session_id = self._validated_session_id(data['session_id'])
        history_key = f'user-session:{session_id}'
        messages = self._load_history(history_key)
        print("history: ", messages)
        text = data['Chat']['text']

        self._emit_chat_status(sid, messages, text, "Generating text...")
        ai_output = self.aihandler(text)
        output_text = ai_output['text']
        output_emotion = ai_output['emotion']
        # Write only the two new turns; HSET merges them into the existing hash.
        self.r.hset(history_key, mapping={
            f"message_{len(messages)}": text,
            f"message_{len(messages) + 1}": output_text,
        })

        self._emit_chat_status(sid, messages, text, "Generating audio...")
        os.makedirs(self._AUDIO_DIR, exist_ok=True)
        wav_path = os.path.join(self._AUDIO_DIR, f"{session_id}.wav")
        self.tts.tts_to_file(text=output_text,
                             speaker_wav=self.speaker_wav,
                             language="en",
                             emotion=output_emotion,
                             file_path=wav_path)
        speech_array, sampling_rate = librosa.load(wav_path, sr=16000)
        infer_output = self.infer(speech_array, sampling_rate)
        np.save(os.path.join(self._AUDIO_DIR, f"{session_id}.npy"), infer_output)

        self._emit_chat_status(sid, messages, text, "Rendering...")
        self._render_with_progress(sid, session_id, infer_output.shape[0])
        construct_video(session_id)
        video_html = (
            '\n<video controls="1" autoplay="1" name="media" style="border-radius: 12px; height: 80%">\n'
            f'<source src="/api/video/{session_id}" type="video/mp4">\n'
            '</video>\n'
        )
        self.emit(self.fields.child_loc["video"].update(video_html), to=sid)
        self._emit_chat_status(sid, messages, text, output_text)

    @staticmethod
    def _validated_session_id(raw):
        """Return *raw* as a canonical UUID string; raise ValueError otherwise."""
        # get_content only ever issues str(uuid4()), so legitimate ids pass
        # unchanged; anything else (e.g. "../x") is rejected before use.
        return str(uuid.UUID(str(raw)))

    def _load_history(self, history_key):
        """Return the stored turns under *history_key* as formatted chat messages."""
        chat = self.fields.child_loc["Chat"]
        stored = self.r.hgetall(history_key)
        return [chat.format_message("self" if i % 2 == 0 else "ChatGPT",
                                    stored[f"message_{i}"])
                for i in range(len(stored))]

    def _emit_chat_status(self, sid, history, text, reply):
        """Show *history*, the current user *text* and an assistant *reply* to *sid*."""
        chat = self.fields.child_loc["Chat"]
        self.emit(chat.update(history + [
            chat.format_message("self", text),
            chat.format_message("ChatGPT", reply),
        ]), to=sid)

    def _render_with_progress(self, sid, session_id, n_frames):
        """Render all frames for *session_id*, pushing a 0-100 progress value to *sid*."""
        progress = self.fields.child_loc["progress"]
        denominator = max(n_frames, 1)  # avoid ZeroDivisionError on empty output
        frames = render_video(f"{session_id}", model_name=self.model_name)
        for i, fname in enumerate(tqdm(frames, total=n_frames)):
            print("Got frame: ", fname, file=sys.stderr)
            self.emit(progress.update(min(100, 100 * i // denominator)), to=sid)
        # The loop tops out at 100*(n-1)//n; mark completion explicitly.
        self.emit(progress.update(100), to=sid)
# Bind address for the Socket.IO server started at the bottom of the file.
host, port, debug = "0.0.0.0", 8011, False

# Route table. "" maps to the string "simple" (presumably an alias to that
# page -- confirm in piedemo.Web). Each MainPage builds its own TTS/inference
# models in its constructor.
web = Web(
    {
        "": "simple",
        "simple": MainPage("render"),
        "nice": MainPage("FemAdv_b350_V2_050523"),
    },
    use_socketio_support=True,
)
app = web.get_app()
@app.route("/api/video/<session_id>", methods=["GET"])
def get_video(session_id):
    """Serve ``<session_id>.mp4`` from the ``./audio`` directory.

    ``send_from_directory`` refuses names that would escape the directory.
    It also answers 404, rather than an unhandled 500, when the video has not
    been produced yet.
    """
    from flask import send_from_directory

    # Use an absolute path so the lookup stays relative to the working
    # directory, as the previous send_file("./audio/...") call did. Flask
    # would otherwise resolve a relative directory against the app root path.
    audio_dir = os.path.abspath("./audio")
    return send_from_directory(audio_dir, f"{session_id}.mp4")
def gen(session_id):
    """Yield rendered frames as multipart chunks, then assemble the final video.

    Each frame file produced by ``render_video`` is wrapped in a ``--frame``
    part carrying a ``Content-Type: image/jpeg`` header. ``construct_video``
    runs once the frame iterator is exhausted.
    """
    part_header = b'--frame\r\nContent-Type: image/jpeg\r\n\r\n'
    for frame_path in render_video(f"{session_id}"):
        with open(frame_path, 'rb') as frame_file:
            frame_bytes = frame_file.read()
        yield part_header + frame_bytes + b'\r\n'
    construct_video(session_id)
@app.route("/api/video/stream/<session_id>", methods=["GET"])
def get_video_async(session_id):
    """Stream frames for *session_id* as they are rendered (multipart/x-mixed-replace)."""
    frame_stream = gen(session_id)
    stream_mimetype = 'multipart/x-mixed-replace; boundary=frame'
    return Response(frame_stream, mimetype=stream_mimetype)
# Socket.IO server object for the Flask app; the code below registers event
# handlers on it (@io.on) and starts it (io.run).
# NOTE(review): this rebinds `io`, shadowing the stdlib `io` module imported at
# the top of the file, so any later io.BytesIO/io.StringIO use would break.
# Consider renaming; every @io.on / io.run reference must change with it.
io = web.get_socketio(app)
@io.on("io_set_text")
def io_set_text(data):
    """Demo handler: stream random coefficient updates, then a canned audio clip.

    Args:
        data: Socket.IO payload; expected to be a dict with a ``"text"`` key.
            The text itself is not used yet -- the reply is a fixed demo sequence.

    Emits:
        ``io_error`` and stops if the payload has no ``"text"``. Otherwise it
        emits ten ``io_set_coef`` events (random index 0-2, value 0.0-0.9),
        one ``io_push_audio_blob`` carrying ``../feeling_good.wav`` as base64,
        and a final ``io_finish``.
    """
    # None = flask_socketio's default recipient (the client that sent the event).
    sid = None
    if not isinstance(data, dict) or "text" not in data:
        emit("io_error", {"message": "Text not found"},
             to=sid)
        return
    with open("../feeling_good.wav", "rb") as audio_file:
        # b64encode returns bytes; decode so the payload carries plain base64
        # text instead of the "b'...'" repr of a bytes object.
        encode_string = base64.b64encode(audio_file.read()).decode("ascii")
    for i in range(10):
        j = random.randint(0, 2)
        emit("io_set_coef", [{
            "index": j,
            "value": i / 10,
        }], to=sid)
    emit("io_push_audio_blob", {
        "dataURL": f"base64,{encode_string}"
    }, to=sid)
    emit("io_finish", {}, to=sid)
# Start the Socket.IO server on the configured bind address.
# allow_unsafe_werkzeug permits flask_socketio to run on the Werkzeug
# development server outside debug mode.
io.run(app, **{
    "host": host,
    "port": port,
    "debug": debug,
    "allow_unsafe_werkzeug": True,
})