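"""Chat demo web app: a piedemo-based chat UI backed by an LLM handler
(AIHandler / AIHandlerStream), XTTS speech synthesis, and an
audio-to-blendshape model (PieInfer). Replies are delivered either as a
rendered MP4 clip (HTTP routes below) or as a live stream of per-frame
blendshape events over Socket.IO."""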
import base64
import json
import logging
import os
import random
import sys
import uuid
from collections import defaultdict
from time import sleep, time

import librosa
import numpy as np
import redis
import torch
from flask import Response, jsonify, request, send_file
from flask_socketio import emit
from tqdm import tqdm
from TTS.api import TTS

from piedemo.fields.ajax_group import AjaxChatField, AjaxGroup
from piedemo.fields.grid import HStack, VStack
from piedemo.fields.inputs.hidden import InputHiddenField
from piedemo.fields.navigation import Navigation
from piedemo.fields.outputs.colored_text import OutputColoredTextField
from piedemo.fields.outputs.progress import ProgressField
from piedemo.hub.svgpil import SVGImage
from piedemo.hub.swagger_utils.method import check_missing_keys, describe
from piedemo.page import Page
from piedemo.web import Web

from aihandler import AIHandler
from aihandler_stream import AIHandlerStream
from pieinfer import PieInfer, construct_video, render_video

# Silence the noisy per-packet logging from the websocket stack.
logging.getLogger('socketio').setLevel(logging.ERROR)
logging.getLogger('engineio').setLevel(logging.ERROR)
# Mouth blendshapes whose activation marks a frame as "expressive";
# perform_on_text() archives such frames to disk.
target_names = [
    "mouthSmileLeft",
    "mouthSmileRight",
    "mouthStretchLeft",
    "mouthStretchRight",
    "mouthUpperUpLeft",
    "mouthUpperUpRight",
]
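# Full blendshape table. The ordering matters: events sent to the client
# reference blendshapes by index into this list. The names match Apple's
# ARKit face-tracking blendshape set.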
model_bsList = ["browDownLeft",
|
|
"browDownRight",
|
|
"browInnerUp",
|
|
"browOuterUpLeft",
|
|
"browOuterUpRight",
|
|
"cheekPuff",
|
|
"cheekSquintLeft",
|
|
"cheekSquintRight",
|
|
"eyeBlinkLeft",
|
|
"eyeBlinkRight",
|
|
"eyeLookDownLeft",
|
|
"eyeLookDownRight",
|
|
"eyeLookInLeft",
|
|
"eyeLookInRight",
|
|
"eyeLookOutLeft",
|
|
"eyeLookOutRight",
|
|
"eyeLookUpLeft",
|
|
"eyeLookUpRight",
|
|
"eyeSquintLeft",
|
|
"eyeSquintRight",
|
|
"eyeWideLeft",
|
|
"eyeWideRight",
|
|
"jawForward",
|
|
"jawLeft",
|
|
"jawOpen",
|
|
"jawRight",
|
|
"mouthClose",
|
|
"mouthDimpleLeft",
|
|
"mouthDimpleRight",
|
|
"mouthFrownLeft",
|
|
"mouthFrownRight",
|
|
"mouthFunnel",
|
|
"mouthLeft",
|
|
"mouthLowerDownLeft",
|
|
"mouthLowerDownRight",
|
|
"mouthPressLeft",
|
|
"mouthPressRight",
|
|
"mouthPucker",
|
|
"mouthRight",
|
|
"mouthRollLower",
|
|
"mouthRollUpper",
|
|
"mouthShrugLower",
|
|
"mouthShrugUpper",
|
|
"mouthSmileLeft",
|
|
"mouthSmileRight",
|
|
"mouthStretchLeft",
|
|
"mouthStretchRight",
|
|
"mouthUpperUpLeft",
|
|
"mouthUpperUpRight",
|
|
"noseSneerLeft",
|
|
"noseSneerRight",
|
|
"tongueOut"]
|
|
|
|
|
|
# Pick the GPU when available.
device = "cuda" if torch.cuda.is_available() else "cpu"
blendshapes_path = "./blendshapes"

def get_asset(fname):
    """Load an SVG asset bundled next to this file and return its markup."""
    return SVGImage.open(os.path.join(os.path.dirname(__file__),
                                      "assets",
                                      fname)).svg_content

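# One page class serves both navigation entries; model_name selects the
# renderer backend passed to render_video() ("render" for the simple view).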
class MainPage(Page):
    def __init__(self, model_name: str):
        super(MainPage, self).__init__()
        self.infer = PieInfer()
        self.tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(device)

        self.r = redis.Redis(host='localhost', port=6379, decode_responses=True)
        self.aihandler = AIHandler()
        self.aihandler_stream = AIHandlerStream()

        self.fields = Navigation(AjaxGroup("ChatGroup", VStack([
            HStack([
                AjaxChatField("Chat",
                              self.register_ajax(f"/refresh_{model_name}",
                                                 self.message_sent),
                              deps_names=["sid",
                                          "session_id",
                                          "Chat",
                                          "Chat__piedemo__file"],
                              use_socketio_support=True,
                              nopie=True,
                              style={
                                  "height": "100%"
                              }),
                OutputColoredTextField("video",
                                       nopie=True,
                                       use_socketio_support=True),
            ], xs=[8, 4]),
            ProgressField("progress",
                          nopie=True,
                          use_socketio_support=True),
            InputHiddenField("session_id", None),
        ]), no_return=True), no_submit=True, page_title="MIA PIA", page_style={})
        self.fields.add_link("SIMPLE",
                             "/simple",
                             active=model_name == "render")
        self.fields.add_link("MIA PIA",
                             "/nice",
                             active=model_name != "render")
        self.model_name = model_name

    def get_content(self, **kwargs):
        fields = self.fields.copy()
        fields.child_loc["Chat"].set_default_options(["Hello! What is your name?", "Say one word and stop."])
        # Custom avatars, currently disabled:
        # fields.child_loc["Chat"].set_avatars({
        #     "self": get_asset("avatar.svg"),
        #     "ChatGPT": get_asset("dog.svg"),
        # })
        session_id = str(uuid.uuid4())
        return self.fill(fields, {
            "video": "",  # empty until the first reply has been rendered
            "session_id": session_id,
        })

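    # Full chat round-trip, driven over Socket.IO: echo the user message,
    # generate reply text and an emotion tag, synthesize speech, infer
    # blendshapes, render frames (with progress updates), then swap in the
    # finished MP4.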
    def message_sent(self, **data):
        sid = data['sid']
        self.emit(self.fields.child_loc["Chat"].clear_input(),
                  to=sid)
        # Clear any previously shown video for this client.
        self.emit(self.fields.child_loc["video"].update(""), to=sid)
        data = self.parse(self.fields, data)
        session_id = data['session_id']

        # Chat history lives in redis as message_0, message_1, ...;
        # even indices are the user, odd indices are the bot.
        messages_map = self.r.hgetall(f'user-session:{session_id}')
        messages = [self.fields.child_loc["Chat"].format_message("self" if i % 2 == 0 else "ChatGPT",
                                                                 messages_map[f"message_{i}"])
                    for i in range(len(messages_map))]

        print("history: ", messages)

        text = data['Chat']['text']

        self.emit(self.fields.child_loc["Chat"].update(messages + [
            self.fields.child_loc["Chat"].format_message("self", text),
            self.fields.child_loc["Chat"].format_message("ChatGPT", "Generating text..."),
        ]), to=sid)

        output = self.aihandler(text)
        output_text = output['text']
        output_emotion = output['emotion']

        messages_map[f"message_{len(messages)}"] = text
        messages_map[f"message_{len(messages) + 1}"] = output_text
        self.r.hset(f'user-session:{session_id}', mapping=messages_map)

        self.emit(self.fields.child_loc["Chat"].update(messages + [
            self.fields.child_loc["Chat"].format_message("self", text),
            self.fields.child_loc["Chat"].format_message("ChatGPT", "Generating audio..."),
        ]), to=sid)

        self.tts.tts_to_file(text=output_text,
                             speaker_wav="/home/ubuntu/repo/of_couse_here.wav",
                             language="en",
                             emotion=output_emotion,
                             file_path=f"./audio/{session_id}.wav")
        speech_array, sampling_rate = librosa.load(f"./audio/{session_id}.wav",
                                                   sr=16000)
        output = self.infer(speech_array, sampling_rate)
        np.save(os.path.join("./audio", "{}.npy".format(session_id)),
                output)

        self.emit(self.fields.child_loc["Chat"].update(messages + [
            self.fields.child_loc["Chat"].format_message("self", text),
            self.fields.child_loc["Chat"].format_message("ChatGPT", "Rendering..."),
        ]), to=sid)

        # Render frame by frame, forwarding progress to the client.
        n = output.shape[0]
        for i, fname in enumerate(tqdm(render_video(f"{session_id}",
                                                    model_name=self.model_name),
                                       total=n)):
            print("Got frame: ", fname, file=sys.stderr)
            self.emit(self.fields.child_loc["progress"].update(100 * i // n),
                      to=sid)
        construct_video(session_id)

        self.emit(self.fields.child_loc["video"].update(f"""
        <video controls="1" autoplay="1" name="media" style="border-radius: 12px; height: 80%">
            <source src="/api/video/{session_id}" type="video/mp4">
        </video>
        """), to=sid)

        # MJPEG-streaming alternative, currently disabled:
        # self.emit(self.fields.child_loc["video"].update(f"""
        # <img name="media" style="border-radius: 12px; height: 80%" src="/api/video/stream/{session_id}">
        # """))
        self.emit(self.fields.child_loc["Chat"].update(messages + [
            self.fields.child_loc["Chat"].format_message("self", text),
            self.fields.child_loc["Chat"].format_message("ChatGPT", output_text),
        ]), to=sid)

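# A single page instance backs both routes; only the navigation link
# highlighting differs between /simple and /nice.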
page = MainPage("render")
web = Web({
    "": "simple",
    "simple": page,
    "nice": page,
}, use_socketio_support=True)


host = '0.0.0.0'
port = 8011
debug = False
app = web.get_app()

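# Plain HTTP API, used alongside the Socket.IO channel.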
@app.route("/api/video/<session_id>", methods=["GET"])
|
|
def get_video(session_id):
|
|
return send_file("./audio/{}.mp4".format(session_id))
|
|
|
|
|
|
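# Frames are pushed as an MJPEG multipart stream while they render; the
# MP4 is assembled afterwards so later requests can fetch the full clip.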
def gen(session_id):
    for image_path in render_video(f"{session_id}"):
        with open(image_path, 'rb') as f:
            yield (b'--frame\r\n'
                   b'Content-Type: image/jpeg\r\n\r\n' + f.read() + b'\r\n')
    construct_video(session_id)


@app.route("/api/video/stream/<session_id>", methods=["GET"])
def get_video_async(session_id):
    return Response(gen(session_id),
                    mimetype='multipart/x-mixed-replace; boundary=frame')

# Default speaker reference clip (perform_on_text() currently hardcodes
# its own per-voice paths).
speaker_path = "/home/ubuntu/repo/female.wav"

@app.route("/api/set_speaker", methods=["POST"])
|
|
@describe(["3dmodel"],
|
|
name="Set emotion for 3D model",
|
|
description="""Set speaker for 3D model""",
|
|
inputs={
|
|
"user_id": "This ID from article Unique Identifier for iPHONE",
|
|
"speaker": "voice1 or voice2"
|
|
},
|
|
outputs={
|
|
"status": "ok"
|
|
})
|
|
@check_missing_keys([
|
|
("user_id", {"status": "error", "status_code": "missing_user_id_error"}),
|
|
("speaker", {"status": "error", "status_code": "missing_emotion_error"}),
|
|
])
|
|
def set_speaker():
|
|
speaker = request.json.get("speaker")
|
|
user_id = request.json.get("user_id")
|
|
SPEAKER[user_id] = speaker
|
|
return jsonify({
|
|
'status': 'ok'
|
|
})
|
|
|
|
|
|
@app.route("/api/set_emotion", methods=["POST"])
|
|
@describe(["3dmodel"],
|
|
name="Set emotion for 3D model",
|
|
description="""Set emotion for 3D model""",
|
|
inputs={
|
|
"user_id": "This ID from article Unique Identifier for iPHONE",
|
|
"emotion": "sad"
|
|
},
|
|
outputs={
|
|
"status": "ok"
|
|
})
|
|
@check_missing_keys([
|
|
("user_id", {"status": "error", "status_code": "missing_user_id_error"}),
|
|
("emotion", {"status": "error", "status_code": "missing_emotion_error"}),
|
|
])
|
|
def set_emotion():
|
|
emotion = request.json.get("emotion")
|
|
user_id = request.json.get("user_id")
|
|
EMOTIONS[user_id] = emotion
|
|
return jsonify({
|
|
'status': 'ok'
|
|
})
|
|
|
|
|
|
@app.route("/api/get_texts", methods=["POST"])
|
|
@describe(["text"],
|
|
name="Get texts for user_id",
|
|
description="""This endpoint get all texts for current iPhone""",
|
|
inputs={
|
|
"user_id": "This ID from article Unique Identifier for iPHONE"
|
|
},
|
|
outputs={
|
|
"text": "Output",
|
|
"id": "bot or user",
|
|
})
|
|
@check_missing_keys([
|
|
("user_id", {"status": "error", "status_code": "missing_user_id_error"}),
|
|
])
|
|
def get_texts():
|
|
user_id = request.json.get("user_id")
|
|
return jsonify(TEXTS[user_id])
|
|
|
|
|
|
@app.route("/api/send_text", methods=["POST"])
|
|
@describe(["text"],
|
|
name="Sent text to miapia",
|
|
description="""This endpoint sends texts for client""",
|
|
inputs={
|
|
"text": "Hello, MIAPIA",
|
|
"user_id": "This ID from article Unique Identifier for iPHONE"
|
|
},
|
|
outputs={
|
|
"status": "ok"
|
|
})
|
|
@check_missing_keys([
|
|
("text", {"status": "error", "status_code": "missing_text_error"}),
|
|
("user_id", {"status": "error", "status_code": "missing_user_id_error"}),
|
|
])
|
|
def send_text():
|
|
user_id = request.json.get("user_id")
|
|
text = request.json.get("text", "")
|
|
TEXTS[user_id].append({
|
|
"id": 'user',
|
|
"text": text
|
|
})
|
|
output_texts = page.aihandler_stream(text)
|
|
bot_text = ""
|
|
for output_text in output_texts:
|
|
bot_text += " " + output_text
|
|
TEXTS[user_id].append({
|
|
"id": 'bot',
|
|
"text": bot_text
|
|
})
|
|
return jsonify({
|
|
"status": "ok",
|
|
"messages": TEXTS[user_id]
|
|
})
|
|
|
|
|
|
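# Socket.IO server plus process-local state. TEXTS holds the chat history
# per user_id; SPEAKER and EMOTIONS hold per-user overrides set through
# the REST endpoints above. All of this is lost on restart.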
io = web.get_socketio(app,
                      engineio_logger=False)
head_memories = {}
TEXTS = defaultdict(list)
EMOTIONS = {}
SPEAKER = {}

def get_event(name, value, timestamp):
    """Build one blendshape event, addressed by index into model_bsList."""
    return {
        "index": model_bsList.index(name),
        "value": value,
        "timestamp": timestamp
    }


def get_value(events, name):
    """Return the latest value for a named blendshape, or None if absent."""
    index = model_bsList.index(name)
    events = [event for event in events
              if event['index'] == index]
    if len(events) == 0:
        return None
    return events[-1]['value']

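# Head motion is approximated by a bounded random walk over eight extra
# event channels (indices 100-110, outside the blendshape table;
# presumably head-pose controls on the client). Each memory entry is
# [current, min, max].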
def get_head_memory():
    ids = [100, 101, 103, 104, 106, 107, 109, 110]
    return [[0, 0, 1] for _ in range(len(ids))]


def get_head_rotations(alpha, duration, memory, sign):
    ids = [100, 101, 103, 104, 106, 107, 109, 110]
    # Nudge three randomly chosen channels, clamping each to its [min, max].
    for _ in range(3):
        index = random.randrange(len(ids))
        step = 0.01 * sign[index]
        memory[index][0] += step
        memory[index][0] = min(memory[index][0], memory[index][2])
        memory[index][0] = max(memory[index][0], memory[index][1])
    # print(memory)  # debug
    return [{
        "index": j,
        "value": memory[i][0],
        "timestamp": float(duration * alpha)
    } for i, j in enumerate(ids)], memory

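# perform_on_text: synthesize speech for one reply chunk, push the audio
# to the client, then stream per-frame blendshape events timed against
# the clip duration.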
def perform_on_text(output_text, sid, head_memory, sign, voice):
    session_id = str(uuid.uuid4())
    page.tts.tts_to_file(text=output_text,
                         speaker_wav="/home/ubuntu/repo/female.wav" if voice == "voice1" else "/home/ubuntu/repo/indian.wav",
                         language="en",
                         emotion="Happy",
                         file_path=f"./audio/{session_id}.wav")

    audio_path = f"./audio/{session_id}.wav"
    with open(audio_path, 'rb') as f:
        audio_content = f.read()

    # Ship the raw audio to the browser as base64.
    encode_string = base64.b64encode(audio_content).decode('utf-8')
    speech_array, sampling_rate = librosa.load(audio_path,
                                               sr=16000)
    duration = librosa.get_duration(y=speech_array,
                                    sr=sampling_rate)
    output = page.infer(speech_array, sampling_rate)
    emit("io_push_audio_blob", {
        "dataURL": f"base64,{encode_string}"
    }, to=sid)
    print("Sent audio.")
    emit("io_set_size", {
        "size": output.shape[0],
    }, to=sid)
    t1 = time()
    for i in tqdm(range(output.shape[0])):
        rots, head_memory = get_head_rotations((i / output.shape[0]), duration, head_memory, sign)
        blendshapes_i = [{
            "index": j,
            # Plain float: numpy scalars are not JSON-serializable.
            "value": float(output[i, j]),
            "timestamp": float(duration * (i / output.shape[0]))
        } for j in range(output.shape[1])] + rots
        # Archive frames with strong mouth activity for later inspection;
        # get_value() returns None for names missing from this frame.
        target_values = [get_value(blendshapes_i, target_name)
                         for target_name in target_names]
        if max((v for v in target_values if v is not None), default=0) > 0.5:
            os.makedirs(blendshapes_path,
                        exist_ok=True)
            save_blendshapes_i = os.path.join(blendshapes_path,
                                              str(uuid.uuid4()) + '.json')
            with open(save_blendshapes_i, 'w') as f:
                json.dump(blendshapes_i, f)
        emit("io_set_coef", blendshapes_i, to=sid)
        # sleep(0.1 * duration / output.shape[0])
    # Pace the burst so audio playback and events stay roughly aligned.
    t2 = time()
    sleep(max(0., duration - (t2 - t1)))
    return head_memory

def perform_surgery(sid, duration=5):
    """Play a few seconds of silence while easing the eyeWide blendshapes
    back to neutral."""
    with open("../5-seconds-of-silence.wav", 'rb') as f:
        audio_content = f.read()
    encode_string = base64.b64encode(audio_content).decode('utf-8')
    fps = 20
    emit("io_push_audio_blob", {
        "dataURL": f"base64,{encode_string}"
    }, to=sid)
    print("Sent audio.")
    emit("io_set_size", {
        "size": (fps * duration)
    }, to=sid)
    t1 = time()
    for i in tqdm(range(fps * duration)):
        alpha = float(i / (fps * duration))
        emit("io_set_coef", [
            get_event("eyeWideLeft",
                      0.3 - 0.3 * alpha,
                      float(duration * alpha)),
            get_event("eyeWideRight",
                      0.3 - 0.3 * alpha,
                      float(duration * alpha))
        ], to=sid)
    t2 = time()
    sleep(max(0., duration - (t2 - t1)))

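# Main Socket.IO entry point: the client sends a JSON-encoded [payload]
# list; the reply text is streamed chunk by chunk, each chunk performed
# with perform_on_text().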
@io.on("io_set_text")
|
|
def io_set_text(data):
|
|
data = json.loads(data)
|
|
data = data[0]
|
|
sid = None
|
|
print(data, file=sys.stderr)
|
|
if "text" not in data:
|
|
emit("io_error", {"message": "Text not found"},
|
|
to=sid)
|
|
return
|
|
|
|
text = data["text"]
|
|
|
|
"""if "user_id" not in data:
|
|
emit("io_error", {"message": "User not found"},
|
|
to=sid)
|
|
return"""
|
|
user_id = data.get('user_id')
|
|
print(user_id)
|
|
TEXTS[user_id].append({
|
|
"id": "user",
|
|
"text": text
|
|
})
|
|
voice = SPEAKER.get(user_id, "voice1")
|
|
|
|
if sid not in head_memories:
|
|
head_memories[sid] = get_head_memory()
|
|
head_memory = head_memories[sid]
|
|
# output_texts = [page.aihandler(text)['text']]
|
|
output_texts = page.aihandler_stream(text)
|
|
bot_text = ""
|
|
for output_text in output_texts:
|
|
sign = [2 * (random.random() > 0.5) - 1
|
|
for _ in range(8)]
|
|
head_memory = perform_on_text(output_text, sid, head_memory,
|
|
sign=sign,
|
|
voice=voice)
|
|
bot_text += " " + output_text
|
|
print("SURGERY STARTED!")
|
|
# perform_surgery(sid)
|
|
print("SURGERY ENDED!")
|
|
TEXTS[user_id].append({
|
|
"id": "bot",
|
|
"text": bot_text
|
|
})
|
|
emit("io_finish", {}, to=sid)
|
|
|
|
|
|
io.run(app,
       host=host, port=port, debug=debug,
       allow_unsafe_werkzeug=True)