import sys
|
|
|
|
import pandas as pd
|
|
import argparse
|
|
import base64
|
|
|
|
from flask import send_file, Response
|
|
from flask_socketio import emit
|
|
from piedemo.fields.ajax_group import AjaxChatField, AjaxGroup
|
|
from piedemo.fields.grid import VStack, HStack, SpaceField
|
|
from piedemo.fields.inputs.hidden import InputHiddenField
|
|
from piedemo.fields.outputs.colored_text import ptext, OutputColoredTextField
|
|
from piedemo.fields.outputs.json import OutputJSONField
|
|
from piedemo.fields.outputs.progress import ProgressField
|
|
from piedemo.fields.outputs.video import OutputVideoField
|
|
from piedemo.web import Web
|
|
import os
|
|
import io
|
|
from piedemo.page import Page
|
|
from piedemo.hub.svgpil import SVGImage
|
|
from piedemo.fields.outputs.table import OutputTableField
|
|
from piedemo.fields.inputs.int_list import InputIntListField
|
|
from piedemo.fields.navigation import Navigation
|
|
from piedemo.fields.inputs.chat import ChatField
|
|
import librosa
|
|
import uuid
|
|
import numpy as np
|
|
import redis
|
|
import argparse
|
|
from scipy.signal import savgol_filter
|
|
import torch
|
|
import random
|
|
import os, subprocess
|
|
import shlex
|
|
|
|
from tqdm import tqdm
|
|
|
|
from aihandler import AIHandler
|
|
from pieinfer import PieInfer, render_video, construct_video
|
|
import torch
|
|
from TTS.api import TTS
|
|
|
|
# Select the torch device once at import time: prefer the GPU when one
# is available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
|
def get_asset(fname):
    """Return the SVG markup of a bundled asset.

    Looks up *fname* inside the ``assets`` directory that sits next to
    this module and returns the loaded image's ``svg_content``.
    """
    asset_dir = os.path.join(os.path.dirname(__file__), "assets")
    image = SVGImage.open(os.path.join(asset_dir, fname))
    return image.svg_content
|
|
|
|
|
|
class MainPage(Page):
    """Chat page that turns each user message into a talking-head video.

    Pipeline per message: chat text -> AIHandler (reply text + emotion)
    -> XTTS speech synthesis -> PieInfer audio-to-animation output
    -> frame rendering -> MP4 served from ``/api/video/<session_id>``.
    Conversation history is persisted in Redis under the key
    ``user-session:<session_id>``.
    """

    def __init__(self, model_name: str):
        super(MainPage, self).__init__()
        self.infer = PieInfer()
        # Multilingual XTTS v2 voice model, placed on the detected device.
        self.tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(device)

        self.r = redis.Redis(host='localhost', port=6379, decode_responses=True)
        self.aihandler = AIHandler()

        # Layout: chat pane (8/12) beside the video pane (4/12), a progress
        # bar, and a hidden per-tab session id filled in by get_content().
        self.fields = Navigation(AjaxGroup("ChatGroup", VStack([
            HStack([
                AjaxChatField("Chat",
                              self.register_ajax(f"/refresh_{model_name}",
                                                 self.message_sent),
                              deps_names=["sid",
                                          "session_id",
                                          "Chat",
                                          "Chat__piedemo__file"],
                              use_socketio_support=True,
                              nopie=True,
                              style={
                                  "height": "100%"
                              }),
                OutputColoredTextField("video",
                                       nopie=True,
                                       use_socketio_support=True),
            ], xs=[8, 4]),
            ProgressField("progress",
                          nopie=True,
                          use_socketio_support=True),
            InputHiddenField("session_id", None),
        ]), no_return=True), no_submit=True, page_title="MIA PIA", page_style={
        })
        self.fields.add_link("SIMPLE",
                             "/simple",
                             active=model_name == "render")
        self.fields.add_link("MIA PIA",
                             "/nice",
                             active=model_name != "render")
        self.model_name = model_name

    def get_content(self, **kwargs):
        """Render the page: fresh session id, empty video pane, chat presets."""
        fields = self.fields.copy()
        fields.child_loc["Chat"].set_default_options(["Hello! What is your name?", "Say one word and stop."])
        """
        fields.child_loc["Chat"].set_avatars({
            "self": get_asset("avatar.svg"),
            "ChatGPT": get_asset("dog.svg"),
        })
        """
        # A new session id per page load keys both the Redis history and the
        # on-disk audio/video artifacts for this visitor.
        session_id = str(uuid.uuid4())
        return self.fill(fields, {
            "video": f"""
""",
            "session_id": session_id,
        })

    def message_sent(self, **data):
        """Ajax/socket handler for one chat message; runs the full pipeline.

        Expects raw ajax ``data`` containing at least ``sid`` (the client's
        socket id), ``session_id`` and the ``Chat`` payload. Streams
        intermediate status updates ("Generating text/audio...",
        "Rendering...") to the client as the slow steps complete.
        """
        sid = data['sid']
        self.emit(self.fields.child_loc["Chat"].clear_input(),
                  to=sid)
        # Bug fix: clear the video pane only for this client. This emit
        # previously had no ``to=sid`` and broadcast to every connected
        # client, unlike every other emit in this handler.
        self.emit(self.fields.child_loc["video"].update(f"""
"""), to=sid)
        data = self.parse(self.fields, data)
        session_id = data['session_id']
        # History is stored flat as message_0, message_1, ... alternating
        # user ("self", even indices) and assistant ("ChatGPT", odd) turns.
        messages_map = self.r.hgetall(f'user-session:{session_id}')
        messages = [self.fields.child_loc["Chat"].format_message("self" if i % 2 == 0 else "ChatGPT",
                                                                 messages_map[f"message_{i}"])
                    for i in range(len(messages_map))]

        print("history: ", messages)

        text = data['Chat']['text']

        self.emit(self.fields.child_loc["Chat"].update(messages + [
            self.fields.child_loc["Chat"].format_message("self", text),
            self.fields.child_loc["Chat"].format_message("ChatGPT", "Generating text..."),
        ]), to=sid)

        output = self.aihandler(text)
        output_text = output['text']
        output_emotion = output['emotion']

        # Persist the new user/assistant turn pair before the slow steps so
        # the history survives even if synthesis or rendering fails.
        messages_map[f"message_{len(messages)}"] = text
        messages_map[f"message_{len(messages) + 1}"] = output_text
        self.r.hset(f'user-session:{session_id}', mapping=messages_map)

        self.emit(self.fields.child_loc["Chat"].update(messages + [
            self.fields.child_loc["Chat"].format_message("self", text),
            self.fields.child_loc["Chat"].format_message("ChatGPT", "Generating audio..."),
        ]), to=sid)

        self.tts.tts_to_file(text=output_text,
                             speaker_wav="/home/ubuntu/repo/of_couse_here.wav",
                             language="en",
                             emotion=output_emotion,
                             file_path=f"./audio/{session_id}.wav")
        # Reload the synthesized speech at 16 kHz (presumably the rate
        # PieInfer expects — TODO confirm against its implementation).
        speech_array, sampling_rate = librosa.load(f"./audio/{session_id}.wav",
                                                   sr=16000)
        output = self.infer(speech_array, sampling_rate)
        np.save(os.path.join("./audio", "{}.npy".format(session_id)),
                output)

        self.emit(self.fields.child_loc["Chat"].update(messages + [
            self.fields.child_loc["Chat"].format_message("self", text),
            self.fields.child_loc["Chat"].format_message("ChatGPT", "Rendering..."),
        ]), to=sid)

        # output.shape[0] is used as the expected frame count for the
        # progress bar; render_video yields one frame path at a time.
        n = output.shape[0]
        for i, fname in enumerate(tqdm(render_video(f"{session_id}",
                                                    model_name=self.model_name),
                                       total=n)):
            print("Got frame: ", fname, file=sys.stderr)
            self.emit(self.fields.child_loc["progress"].update(100 * i // n),
                      to=sid)
        construct_video(session_id)

        self.emit(self.fields.child_loc["video"].update(f"""
<video controls="1" autoplay="1" name="media" style="border-radius: 12px; height: 80%">
<source src="/api/video/{session_id}" type="video/mp4">
</video>
"""), to=sid)

        '''self.emit(self.fields.child_loc["video"].update(f"""
<img name="media" style="border-radius: 12px; height: 80%" src="/api/video/stream/{session_id}"></img>
"""))'''
        self.emit(self.fields.child_loc["Chat"].update(messages + [
            self.fields.child_loc["Chat"].format_message("self", text),
            self.fields.child_loc["Chat"].format_message("ChatGPT", output_text),
        ]), to=sid)
|
# Route table: "" redirects to "simple"; both pages share the same code
# but render with different models ("render" vs. the named avatar model).
web = Web({
    "": "simple",
    "simple": MainPage("render"),
    "nice": MainPage("FemAdv_b350_V2_050523"),
}, use_socketio_support=True)


# Server binding configuration (listen on all interfaces).
host = '0.0.0.0'
port = 8011
debug = False
# Underlying Flask app, used below to register extra API routes.
app = web.get_app()
|
|
|
@app.route("/api/video/<session_id>", methods=["GET"])
def get_video(session_id):
    """Serve the rendered MP4 for a chat session.

    *session_id* comes straight from the URL and is interpolated into a
    filesystem path, so it is validated as a UUID first — this blocks
    path-traversal values such as ``../../etc/passwd``.
    """
    try:
        uuid.UUID(session_id)
    except ValueError:
        return "Not found", 404
    return send_file("./audio/{}.mp4".format(session_id))
|
|
|
|
|
|
def gen(session_id):
    """Yield each rendered frame as one multipart JPEG part.

    Frames are produced lazily by ``render_video``; once the stream is
    exhausted the final MP4 is assembled via ``construct_video``.
    """
    head = b'--frame\r\nContent-Type: image/jpeg\r\n\r\n'
    for frame_path in render_video(f"{session_id}"):
        with open(frame_path, 'rb') as fp:
            payload = fp.read()
        yield head + payload + b'\r\n'
    construct_video(session_id)
|
|
|
|
|
|
@app.route("/api/video/stream/<session_id>", methods=["GET"])
def get_video_async(session_id):
    """Stream the session's frames to the browser as an MJPEG response."""
    frame_stream = gen(session_id)
    return Response(frame_stream,
                    mimetype='multipart/x-mixed-replace; boundary=frame')
|
|
|
|
|
|
# Socket.IO server bound to the Flask app.
# NOTE(review): this rebinds the module-level name ``io``, shadowing the
# stdlib ``io`` module imported at the top of the file — consider renaming
# to ``socketio`` (would also require updating the uses below).
io = web.get_socketio(app)
|
|
|
|
|
|
@io.on("io_set_text")
def io_set_text(data):
    """Socket.IO demo handler: animate random coefficients, then push audio.

    Emits ``io_error`` and stops early when *data* lacks a ``text`` key;
    otherwise sends ten ``io_set_coef`` ticks, one ``io_push_audio_blob``
    with a base64 data URL, and a final ``io_finish``.
    """
    sid = None  # broadcast; no per-client socket id is tracked here
    if "text" not in data:
        emit("io_error", {"message": "Text not found"},
             to=sid)
        # Bug fix: previously fell through and ran the demo anyway.
        return

    # Bug fix: b64encode returns bytes; without .decode() the f-string
    # below embedded the ``b'...'`` repr, producing a corrupt data URL.
    # Context manager also closes the previously-leaked file handle.
    with open("../feeling_good.wav", "rb") as f:
        encode_string = base64.b64encode(f.read()).decode("ascii")
    for i in range(10):
        j = random.randint(0, 2)
        emit("io_set_coef", [{
            "index": j,
            "value": i / 10,
        }], to=sid)
    emit("io_push_audio_blob", {
        "dataURL": f"base64,{encode_string}"
    }, to=sid)
    emit("io_finish", {}, to=sid)
|
|
|
|
|
|
# Start the Socket.IO development server (blocking call).
# allow_unsafe_werkzeug lets flask-socketio run on the werkzeug dev
# server outside debug mode; not intended for production deployment.
io.run(app,
       host=host, port=port, debug=debug,
       allow_unsafe_werkzeug=True)
|