"""MIA PIA demo web app.

Chat UI whose replies are spoken with XTTS and rendered as an avatar video.
Per message: chat text -> AIHandler (reply text + emotion) -> XTTS speech ->
PieInfer animation coefficients -> frame rendering -> final mp4, with status
and progress streamed to the browser over socket.io.
"""

# --- stdlib ---
import argparse
import base64
import io
import os
import random
import shlex
import subprocess
import sys
import uuid

# --- third-party ---
import librosa
import numpy as np
import pandas as pd
import redis
import torch
from flask import Response, send_file
from flask_socketio import emit
from scipy.signal import savgol_filter
from tqdm import tqdm
from TTS.api import TTS

# --- project ---
from aihandler import AIHandler
from piedemo.fields.ajax_group import AjaxChatField, AjaxGroup
from piedemo.fields.grid import HStack, SpaceField, VStack
from piedemo.fields.inputs.chat import ChatField
from piedemo.fields.inputs.hidden import InputHiddenField
from piedemo.fields.inputs.int_list import InputIntListField
from piedemo.fields.navigation import Navigation
from piedemo.fields.outputs.colored_text import OutputColoredTextField, ptext
from piedemo.fields.outputs.json import OutputJSONField
from piedemo.fields.outputs.progress import ProgressField
from piedemo.fields.outputs.table import OutputTableField
from piedemo.fields.outputs.video import OutputVideoField
from piedemo.hub.svgpil import SVGImage
from piedemo.page import Page
from piedemo.web import Web
from pieinfer import PieInfer, construct_video, render_video

# Run the models on GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"


def get_asset(fname):
    """Return the inline SVG content of an asset bundled next to this file."""
    path = os.path.join(os.path.dirname(__file__), "assets", fname)
    return SVGImage.open(path).svg_content


class MainPage(Page):
    """Chat page wiring the piedemo UI fields to the text/speech/video pipeline.

    Args:
        model_name: avatar model used for rendering. ``"render"`` selects the
            simple variant; any other value is treated as a named character
            model (see the link highlighting below).
    """

    def __init__(self, model_name: str):
        super(MainPage, self).__init__()
        self.infer = PieInfer()
        # XTTS v2 multilingual voice-cloning TTS model.
        self.tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(device)
        # Redis holds the per-session chat history; decode_responses=True
        # makes hgetall return str keys/values instead of bytes.
        self.r = redis.Redis(host='localhost', port=6379, decode_responses=True)
        self.aihandler = AIHandler()
        self.fields = Navigation(
            AjaxGroup(
                "ChatGroup",
                VStack([
                    HStack([
                        AjaxChatField(
                            "Chat",
                            self.register_ajax(f"/refresh_{model_name}", self.message_sent),
                            deps_names=["sid", "session_id", "Chat", "Chat__piedemo__file"],
                            use_socketio_support=True,
                            nopie=True,
                            style={"height": "100%"}),
                        OutputColoredTextField("video", nopie=True, use_socketio_support=True),
                    ], xs=[8, 4]),
                    ProgressField("progress", nopie=True, use_socketio_support=True),
                    InputHiddenField("session_id", None),
                ]),
                no_return=True),
            no_submit=True,
            page_title="MIA PIA",
            page_style={})
        self.fields.add_link("SIMPLE", "/simple", active=model_name == "render")
        self.fields.add_link("MIA PIA", "/nice", active=model_name != "render")
        self.model_name = model_name

    def get_content(self, **kwargs):
        """Build a fresh page: empty chat with suggested prompts and a new session id."""
        fields = self.fields.copy()
        fields.child_loc["Chat"].set_default_options(["Hello! What is your name?", "Say one word and stop."])
        # Disabled: custom chat avatars.
        # fields.child_loc["Chat"].set_avatars({
        #     "self": get_asset("avatar.svg"),
        #     "ChatGPT": get_asset("dog.svg"),
        # })
        session_id = str(uuid.uuid4())
        # NOTE(review): the "video" template is blank here — its markup may
        # have been lost; confirm against the deployed version.
        return self.fill(fields, {
            "video": f""" """,
            "session_id": session_id,
        })

    def message_sent(self, **data):
        """Handle one chat turn: generate a reply, speak it, render the avatar video.

        Emits incremental status updates ("Generating text...", "Generating
        audio...", "Rendering...", progress %) to the requesting client socket
        ``sid`` and persists the turn in redis under ``user-session:<id>``.
        """
        sid = data['sid']
        self.emit(self.fields.child_loc["Chat"].clear_input(), to=sid)
        # Blank out the previous video while the new one is being produced.
        self.emit(self.fields.child_loc["video"].update(f""" """))
        data = self.parse(self.fields, data)
        session_id = data['session_id']

        # History alternates user ("self", even indices) / bot ("ChatGPT", odd).
        messages_map = self.r.hgetall(f'user-session:{session_id}')
        messages = [self.fields.child_loc["Chat"].format_message(
                        "self" if i % 2 == 0 else "ChatGPT",
                        messages_map[f"message_{i}"])
                    for i in range(len(messages_map))]
        print("history: ", messages)

        text = data['Chat']['text']
        self.emit(self.fields.child_loc["Chat"].update(messages + [
            self.fields.child_loc["Chat"].format_message("self", text),
            self.fields.child_loc["Chat"].format_message("ChatGPT", "Generating text..."),
        ]), to=sid)

        output = self.aihandler(text)
        output_text = output['text']
        output_emotion = output['emotion']

        # Persist the new turn before the slow audio/video stages so the
        # history survives even if rendering fails.
        messages_map[f"message_{len(messages)}"] = text
        messages_map[f"message_{len(messages) + 1}"] = output_text
        self.r.hset(f'user-session:{session_id}', mapping=messages_map)

        self.emit(self.fields.child_loc["Chat"].update(messages + [
            self.fields.child_loc["Chat"].format_message("self", text),
            self.fields.child_loc["Chat"].format_message("ChatGPT", "Generating audio..."),
        ]), to=sid)
        self.tts.tts_to_file(text=output_text,
                             speaker_wav="/home/ubuntu/repo/of_couse_here.wav",
                             language="en",
                             emotion=output_emotion,
                             file_path=f"./audio/{session_id}.wav")

        # Re-load at 16 kHz — presumably the rate PieInfer expects; confirm.
        speech_array, sampling_rate = librosa.load(f"./audio/{session_id}.wav", sr=16000)
        output = self.infer(speech_array, sampling_rate)
        np.save(os.path.join("./audio", "{}.npy".format(session_id)), output)

        self.emit(self.fields.child_loc["Chat"].update(messages + [
            self.fields.child_loc["Chat"].format_message("self", text),
            self.fields.child_loc["Chat"].format_message("ChatGPT", "Rendering..."),
        ]), to=sid)

        n = output.shape[0]  # one coefficient row per rendered frame
        for i, fname in enumerate(tqdm(render_video(f"{session_id}", model_name=self.model_name), total=n)):
            print("Got frame: ", fname, file=sys.stderr)
            self.emit(self.fields.child_loc["progress"].update(100 * i // n), to=sid)
        construct_video(session_id)

        # NOTE(review): the "video" template below is blank — its markup may
        # have been lost; confirm against the deployed version.
        self.emit(self.fields.child_loc["video"].update(f""" """), to=sid)

        self.emit(self.fields.child_loc["Chat"].update(messages + [
            self.fields.child_loc["Chat"].format_message("self", text),
            self.fields.child_loc["Chat"].format_message("ChatGPT", output_text),
        ]), to=sid)


# Two page variants: a simple debug renderer and the full character model.
web = Web({
    "": "simple",
    "simple": MainPage("render"),
    "nice": MainPage("FemAdv_b350_V2_050523"),
}, use_socketio_support=True)

host = '0.0.0.0'
port = 8011
debug = False
app = web.get_app()


# Fix: the route previously lacked the <session_id> placeholder, so Flask
# could never pass the argument to the view (TypeError at request time).
@app.route("/api/video/<session_id>", methods=["GET"])
def get_video(session_id):
    """Serve the finished mp4 for a session."""
    return send_file("./audio/{}.mp4".format(session_id))


def gen(session_id):
    """Yield rendered frames as a multipart MJPEG stream, then assemble the mp4."""
    for image_path in render_video(f"{session_id}"):
        with open(image_path, 'rb') as f:
            yield (b'--frame\r\n'
                   b'Content-Type: image/jpeg\r\n\r\n' + f.read() + b'\r\n')
    construct_video(session_id)


# Fix: same missing <session_id> placeholder as /api/video/ above.
@app.route("/api/video/stream/<session_id>", methods=["GET"])
def get_video_async(session_id):
    """Stream frames to the browser as they are rendered."""
    return Response(gen(session_id),
                    mimetype='multipart/x-mixed-replace; boundary=frame')


# Renamed from `io`: the old name shadowed the stdlib io module imported above.
socketio = web.get_socketio(app)


@socketio.on("io_set_text")
def io_set_text(data):
    """Demo socket handler: animate random coefficients, then push a canned clip."""
    # NOTE(review): sid is never assigned a real value — confirm whether this
    # should target the requesting client instead of broadcasting.
    sid = None
    if "text" not in data:
        emit("io_error", {"message": "Text not found"}, to=sid)
        return  # fix: previously fell through and streamed audio anyway
    # Fix: close the file handle, and decode — b64encode returns bytes, so the
    # old f-string embedded the b'...' repr instead of the base64 text.
    with open("../feeling_good.wav", "rb") as f:
        encoded = base64.b64encode(f.read()).decode("ascii")
    for i in range(10):
        j = random.randint(0, 2)
        emit("io_set_coef", [{
            "index": j,
            "value": i / 10,
        }], to=sid)
    emit("io_push_audio_blob", {
        "dataURL": f"base64,{encoded}"
    }, to=sid)
    emit("io_finish", {}, to=sid)


if __name__ == "__main__":
    # Guarded so the module can be imported (e.g. by a WSGI runner exposing
    # `app`) without starting the development server.
    socketio.run(app, host=host, port=port, debug=debug, allow_unsafe_werkzeug=True)