initial commit
miapia_own/main.py | 243 (new file)
@@ -0,0 +1,243 @@
import sys
import os
import argparse
import base64
import random
import shlex
import subprocess
import uuid

import librosa
import numpy as np
import pandas as pd
import redis
import torch
from scipy.signal import savgol_filter
from tqdm import tqdm

from flask import send_file, Response
from flask_socketio import emit

from piedemo.fields.ajax_group import AjaxChatField, AjaxGroup
from piedemo.fields.grid import VStack, HStack, SpaceField
from piedemo.fields.inputs.chat import ChatField
from piedemo.fields.inputs.hidden import InputHiddenField
from piedemo.fields.inputs.int_list import InputIntListField
from piedemo.fields.navigation import Navigation
from piedemo.fields.outputs.colored_text import ptext, OutputColoredTextField
from piedemo.fields.outputs.json import OutputJSONField
from piedemo.fields.outputs.progress import ProgressField
from piedemo.fields.outputs.table import OutputTableField
from piedemo.fields.outputs.video import OutputVideoField
from piedemo.hub.svgpil import SVGImage
from piedemo.page import Page
from piedemo.web import Web

from TTS.api import TTS

from aihandler import AIHandler
from pieinfer import PieInfer, render_video, construct_video

# Select the inference device: GPU when available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
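
# Every chat session writes its artifacts under ./audio relative to the working
# directory: the XTTS wav, the saved feature array, and the mp4 served by
# /api/video/<session_id>. A minimal guard, assuming that relative-path
# convention holds wherever the app is started from:
os.makedirs("./audio", exist_ok=True)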


def get_asset(fname):
    return SVGImage.open(os.path.join(os.path.dirname(__file__),
                                      "assets",
                                      fname)).svg_content
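

# MainPage wires up the piedemo UI: an AjaxChatField (left, 8/12 of the row)
# next to an OutputColoredTextField used as an HTML <video> placeholder
# (right, 4/12), a socketio-driven ProgressField, and a hidden per-visit
# session_id. Chat submissions are routed through the /refresh_{model_name}
# ajax endpoint into message_sent().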
class MainPage(Page):
    def __init__(self, model_name: str):
        super(MainPage, self).__init__()
        self.infer = PieInfer()
        self.tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(device)

        self.r = redis.Redis(host='localhost', port=6379, decode_responses=True)
        self.aihandler = AIHandler()

        self.fields = Navigation(AjaxGroup("ChatGroup", VStack([
            HStack([
                AjaxChatField("Chat",
                              self.register_ajax(f"/refresh_{model_name}",
                                                 self.message_sent),
                              deps_names=["sid",
                                          "session_id",
                                          "Chat",
                                          "Chat__piedemo__file"],
                              use_socketio_support=True,
                              nopie=True,
                              style={
                                  "height": "100%"
                              }),
                OutputColoredTextField("video",
                                       nopie=True,
                                       use_socketio_support=True),
            ], xs=[8, 4]),
            ProgressField("progress",
                          nopie=True,
                          use_socketio_support=True),
            InputHiddenField("session_id", None),
        ]), no_return=True), no_submit=True, page_title="MIA PIA", page_style={

        })
        self.fields.add_link("SIMPLE",
                             "/simple",
                             active=model_name == "render")
        self.fields.add_link("MIA PIA",
                             "/nice",
                             active=model_name != "render")
        self.model_name = model_name

    def get_content(self, **kwargs):
        fields = self.fields.copy()
        fields.child_loc["Chat"].set_default_options(["Hello! What is your name?", "Say one word and stop."])
        """
        fields.child_loc["Chat"].set_avatars({
            "self": get_asset("avatar.svg"),
            "ChatGPT": get_asset("dog.svg"),
        })
        """
        session_id = str(uuid.uuid4())
        return self.fill(fields, {
            "video": f"""
""",
            "session_id": session_id,
        })
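
    # message_sent() runs the full pipeline for one chat turn:
    #   user text -> AIHandler (reply text + emotion label) -> XTTS speech
    #   -> PieInfer audio features -> render_video frames -> construct_video mp4,
    # streaming chat, progress and video updates to the originating socket (sid)
    # along the way. Conversation history lives in the Redis hash
    # user-session:{session_id} as message_0, message_1, ... where even indices
    # are the user ("self") and odd indices the assistant ("ChatGPT").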
    def message_sent(self, **data):
        sid = data['sid']
        self.emit(self.fields.child_loc["Chat"].clear_input(),
                  to=sid)
        self.emit(self.fields.child_loc["video"].update(f"""
"""))
        data = self.parse(self.fields, data)
        session_id = data['session_id']

        # Rebuild the chat history for this session from Redis.
        messages_map = self.r.hgetall(f'user-session:{session_id}')
        messages = [self.fields.child_loc["Chat"].format_message("self" if i % 2 == 0 else "ChatGPT",
                                                                 messages_map[f"message_{i}"])
                    for i in range(len(messages_map))]

        print("history: ", messages)

        text = data['Chat']['text']

        # Echo the user message and show a status placeholder while the reply is generated.
        self.emit(self.fields.child_loc["Chat"].update(messages + [
            self.fields.child_loc["Chat"].format_message("self", text),
            self.fields.child_loc["Chat"].format_message("ChatGPT", "Generating text..."),
        ]), to=sid)

        # Generate the reply text and its emotion label.
        output = self.aihandler(text)
        output_text = output['text']
        output_emotion = output['emotion']

        # Persist both turns back into the session hash.
        messages_map[f"message_{len(messages)}"] = text
        messages_map[f"message_{len(messages) + 1}"] = output_text
        self.r.hset(f'user-session:{session_id}', mapping=messages_map)

        self.emit(self.fields.child_loc["Chat"].update(messages + [
            self.fields.child_loc["Chat"].format_message("self", text),
            self.fields.child_loc["Chat"].format_message("ChatGPT", "Generating audio..."),
        ]), to=sid)

        # Synthesize speech for the reply, then extract animation features from it.
        self.tts.tts_to_file(text=output_text,
                             speaker_wav="/home/ubuntu/repo/of_couse_here.wav",
                             language="en",
                             emotion=output_emotion,
                             file_path=f"./audio/{session_id}.wav")
        speech_array, sampling_rate = librosa.load(f"./audio/{session_id}.wav",
                                                   sr=16000)
        output = self.infer(speech_array, sampling_rate)
        np.save(os.path.join("./audio", "{}.npy".format(session_id)),
                output)

        self.emit(self.fields.child_loc["Chat"].update(messages + [
            self.fields.child_loc["Chat"].format_message("self", text),
            self.fields.child_loc["Chat"].format_message("ChatGPT", "Rendering..."),
        ]), to=sid)

        # Render the video frames, reporting progress to the client, then assemble the mp4.
        n = output.shape[0]
        for i, fname in enumerate(tqdm(render_video(f"{session_id}",
                                                    model_name=self.model_name),
                                       total=n)):
            print("Got frame: ", fname, file=sys.stderr)
            self.emit(self.fields.child_loc["progress"].update(100 * i // n),
                      to=sid)
        construct_video(session_id)

        self.emit(self.fields.child_loc["video"].update(f"""
<video controls="1" autoplay="1" name="media" style="border-radius: 12px; height: 80%">
<source src="/api/video/{session_id}" type="video/mp4">
</video>
"""), to=sid)

        '''self.emit(self.fields.child_loc["video"].update(f"""
<img name="media" style="border-radius: 12px; height: 80%" src="/api/video/stream/{session_id}"></img>
"""))'''
        self.emit(self.fields.child_loc["Chat"].update(messages + [
            self.fields.child_loc["Chat"].format_message("self", text),
            self.fields.child_loc["Chat"].format_message("ChatGPT", output_text),
        ]), to=sid)
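

# URL map: "" is an alias for the "simple" page; /simple runs MainPage with the
# bare "render" model and /nice runs it with the FemAdv_b350_V2_050523 model.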
web = Web({
    "": "simple",
    "simple": MainPage("render"),
    "nice": MainPage("FemAdv_b350_V2_050523"),
}, use_socketio_support=True)


host = '0.0.0.0'
port = 8011
debug = False
app = web.get_app()


@app.route("/api/video/<session_id>", methods=["GET"])
def get_video(session_id):
    return send_file("./audio/{}.mp4".format(session_id))
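

# gen() yields rendered frames as a multipart/x-mixed-replace JPEG stream
# (MJPEG), so /api/video/stream/<session_id> can display frames in an <img>
# tag while rendering is still in progress; once all frames have been produced,
# construct_video() assembles the final mp4.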
def gen(session_id):
    for image_path in render_video(f"{session_id}"):
        with open(image_path, 'rb') as f:
            yield (b'--frame\r\n'
                   b'Content-Type: image/jpeg\r\n\r\n' + f.read() + b'\r\n')
    construct_video(session_id)


@app.route("/api/video/stream/<session_id>", methods=["GET"])
def get_video_async(session_id):
    return Response(gen(session_id),
                    mimetype='multipart/x-mixed-replace; boundary=frame')


io = web.get_socketio(app)
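

# Standalone socketio handler, independent of the MainPage pipeline above: on
# "io_set_text" it pushes a few random "io_set_coef" updates and a canned
# base64-encoded clip ("../feeling_good.wav") to the client, then signals
# "io_finish".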
@io.on("io_set_text")
def io_set_text(data):
    sid = None
    if "text" not in data:
        emit("io_error", {"message": "Text not found"},
             to=sid)
        return

    # Base64-encode the sample clip as text so it can be embedded in the JSON
    # payload (b64encode returns bytes, hence the decode).
    with open("../feeling_good.wav", "rb") as f:
        encode_string = base64.b64encode(f.read()).decode("ascii")
    for i in range(10):
        j = random.randint(0, 2)
        emit("io_set_coef", [{
            "index": j,
            "value": i / 10,
        }], to=sid)
    emit("io_push_audio_blob", {
        "dataURL": f"base64,{encode_string}"
    }, to=sid)
    emit("io_finish", {}, to=sid)


if __name__ == "__main__":
    # Start the socketio development server when run directly.
    io.run(app,
           host=host, port=port, debug=debug,
           allow_unsafe_werkzeug=True)