uma-tts-api/app.py
2025-11-23 21:39:29 +09:00

203 lines
6.2 KiB
Python

import torch
from torch import LongTensor, no_grad
import soundfile as sf
import commons
from models import utils
from io import BytesIO
from models.models import SynthesizerTrn
from text import text_to_sequence
from text.symbols import symbols
from flask import Flask, request, jsonify, send_file
import threading
from constant import speakerList
import logging
import imageio_ffmpeg
from pydub import AudioSegment
app = Flask(__name__)

# Configure logging: the root logger writes every record to both app.log
# and the console.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s | %(levelname)s | %(module)s | %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    handlers=[
        logging.FileHandler("app.log"),  # Log to file
        logging.StreamHandler()          # Log to console
    ]
)

# Werkzeug (Flask's request logger) propagates to the root logger, which
# already writes to app.log and the console.  Attaching a second
# FileHandler here (as the previous version did) duplicated every request
# line in app.log; setting the level is sufficient.
logging.getLogger('werkzeug').setLevel(logging.INFO)

# pydub's converter logs every ffmpeg invocation; keep it quiet below WARNING.
logging.getLogger('pydub.converter').setLevel(logging.WARNING)

logger = logging.getLogger(__name__)
logger.info("Logging is configured and ready to use.")

# Model configuration and checkpoint loaded when the server starts.
config_path = "configs/uma.json"
checkpoint_path = "G_790000.pth"

# Speaker mapping: id -> display name, from the project constant.
umas = speakerList
# Reverse mapping from speaker display name to integer speaker id.
umas_name_to_id = {name: int(idx) for idx, name in umas.items()}

# Serialize model inference across Flask worker threads — the loaded
# network is shared module-level state.
model_lock = threading.Lock()

# Point pydub at the bundled ffmpeg binary so no system install is required.
AudioSegment.converter = imageio_ffmpeg.get_ffmpeg_exe()
def get_text(text, hps):
    """
    Turn raw text into a 1-D LongTensor of symbol IDs for the synthesizer.

    The text is cleaned/normalized with the cleaners named in the config;
    when the config sets ``add_blank``, a 0 is interleaved between symbols.
    """
    sequence = text_to_sequence(text, hps.data.text_cleaners)
    if hps.data.add_blank:
        sequence = commons.intersperse(sequence, 0)
    return LongTensor(sequence)
def load_model(config_path, checkpoint_path):
    """
    Build the VITS synthesizer described by *config_path* and restore its
    weights from *checkpoint_path*.

    Returns:
        (model, hps): the generator in eval mode (moved to GPU when CUDA
        is available) and the parsed hyperparameter object.
    """
    # Hyperparameters come from the JSON config file.
    hps = utils.get_hparams_from_file(config_path)

    # Instantiate the VITS generator from the config's model section.
    model = SynthesizerTrn(
        n_vocab=len(symbols),
        spec_channels=hps.data.filter_length // 2 + 1,
        segment_size=hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model,
    )
    model.eval()

    if torch.cuda.is_available():
        model.cuda()
    else:
        logger.warning("CUDA is not available. Running on CPU.")

    # Restore trained weights (no optimizer state needed for inference).
    utils.load_checkpoint(checkpoint_path, model, None)
    logger.info(f"Model checkpoint '{checkpoint_path}' loaded successfully!")
    return model, hps
def synthesize_speech(net_g, hps, text, speaker_id, noise_scale, noise_scale_w, length_scale):
    """
    Run the synthesizer on *text* for *speaker_id* and return the waveform
    as a 1-D float32 numpy array (sample rate: hps.data.sampling_rate).
    """
    sequence = get_text(text, hps)
    inputs = sequence.unsqueeze(0)                # add batch dimension
    lengths = LongTensor([sequence.size(0)])
    speaker = LongTensor([speaker_id])            # speaker_id must be an int index

    if torch.cuda.is_available():
        inputs, lengths, speaker = inputs.cuda(), lengths.cuda(), speaker.cuda()

    with no_grad():
        result = net_g.infer(
            inputs,
            lengths,
            sid=speaker,
            noise_scale=noise_scale,
            noise_scale_w=noise_scale_w,
            length_scale=length_scale,
        )
        # infer() returns a tuple; the first element holds the audio batch.
        audio = result[0][0, 0].data.cpu().float().numpy()
    return audio
# Load the model and hyperparameters once at import time so every request
# reuses the same in-memory network (concurrent access is serialized via
# model_lock in the /synthesize handler).
net_g, hps = load_model(config_path, checkpoint_path)
@app.route('/synthesize', methods=['POST'])
def synthesize():
    """
    Synthesize speech for Japanese text with a named speaker.

    JSON body (all fields optional): speaker_name, text, noise_scale,
    noise_scale_w, length_scale.  Returns the audio as an AAC attachment
    ('ipod'/m4a container), or a 400 JSON error for an unknown speaker or
    non-numeric scale parameters.
    """
    # silent=True returns None instead of aborting on a missing/invalid
    # JSON body; fall back to all-default parameters.  (Previously a
    # request without a JSON body crashed with AttributeError -> 500;
    # /speakers already tolerated a missing body.)
    data = request.get_json(silent=True) or {}

    # Log all request parameters
    logger.info(f"Request parameters: {data}")

    # Extract parameters with default values
    Speaker_Uma = data.get('speaker_name', 'Rice Shower')
    Japanese_text = data.get('text', 'おにー様、すきです')
    try:
        noise_scale = float(data.get('noise_scale', 0.37))
        noise_scale_w = float(data.get('noise_scale_w', 0.46))
        length_scale = float(data.get('length_scale', 1.3))
    except (TypeError, ValueError):
        # Previously a non-numeric scale raised an uncaught ValueError -> 500.
        logger.error("Non-numeric synthesis parameter received.")
        return jsonify({'error': 'noise_scale, noise_scale_w and length_scale must be numbers.'}), 400

    # Resolve the speaker name to its integer model id.
    speaker_id = umas_name_to_id.get(Speaker_Uma)
    if speaker_id is None:
        logger.error(f"Speaker '{Speaker_Uma}' not found.")
        return jsonify({'error': f"Speaker '{Speaker_Uma}' not found."}), 400

    # Generate speech audio; the lock serializes access to the shared model.
    logger.info(f"Generating synthesis for speaker: {Speaker_Uma}")
    with model_lock:
        audio = synthesize_speech(
            net_g, hps, Japanese_text, speaker_id, noise_scale, noise_scale_w, length_scale
        )

    # Save the audio as WAV in memory.
    wav_buffer = BytesIO()
    sf.write(wav_buffer, audio, hps.data.sampling_rate, format='WAV')
    wav_buffer.seek(0)

    # Convert WAV to AAC (m4a container, pydub format name 'ipod') via ffmpeg.
    wav_audio = AudioSegment.from_file(wav_buffer, format="wav")
    aac_buffer = BytesIO()
    wav_audio.export(aac_buffer, format="ipod")
    aac_buffer.seek(0)

    logger.info("Audio synthesis completed successfully.")
    logger.info("Sending synthesized audio file")
    # Return the audio file as a download.
    return send_file(
        aac_buffer,
        mimetype='audio/aac',
        as_attachment=True,
        download_name='output.aac'
    )
@app.route('/speakers', methods=['POST'])
def get_speakers():
    """
    Return the list of speaker names as JSON.

    Accepts an optional JSON payload with a 'search' string; when present
    and non-empty, only names containing it (case-insensitive) are returned.
    """
    data = request.get_json()

    # Log all request parameters
    logger.info(f"Request parameters: {data}")

    search = data.get('search', '') if data else ''
    needle = search.lower()

    # Empty search term -> all speakers; otherwise case-insensitive substring match.
    speaker_list = [name for name in umas.values() if not search or needle in name.lower()]

    logger.info(f"Sending speaker list for search term: '{search}' with {len(speaker_list)} results.")
    return jsonify(speaker_list)
@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe: confirms the Flask process is up and responding."""
    logger.info("Health check requested.")
    return jsonify(status='ok')
if __name__ == '__main__':
    # Development entry point: bind on all interfaces on the service's
    # fixed port 18343 (use a WSGI server such as gunicorn in production).
    app.run(host='0.0.0.0', port=18343)